mirror of https://github.com/rwf2/Rocket.git
Update to hyper 1. Enable custom + unix listeners.
This commit completely rewrites Rocket's HTTP serving. In addition to significant internal cleanup, this commit introduces the following major features: * Support for custom, external listeners in the `listener` module. The new `listener` module contains new `Bindable`, `Listener`, and `Connection` traits which enable composable, external implementations of connection listeners. Rocket can launch on any `Listener`, or anything that can be used to create a listener (`Bindable`), via a new `launch_on()` method. * Support for Unix domain socket listeners out of the box. The default listener backwards compatibly supports listening on Unix domain sockets. To do so, configure an `address` of `unix:path/to/socket` and optional set `reuse` to `true` (the default) or `false` which controls whether Rocket will handle creating and deleting the unix domain socket. In addition to these new features, this commit makes the following major improvements: * Rocket now depends on hyper 1. * Rocket no longer depends on hyper to handle connections. This allows us to handle more connection failure conditions which results in an overall more robust server with fewer dependencies. * Logic to work around hyper's inability to reference incoming request data in the response results in a 15% performance improvement. * `Client`s can be marked secure with `Client::{un}tracked_secure()`, allowing Rocket to treat local connections as running under TLS. * The `macros` feature of `tokio` is no longer used by Rocket itself. Dependencies can take advantage of this reduction in compile-time cost by disabling the new default feature `tokio-macros`. * A new `TlsConfig::validate()` method allows checking a TLS config. * New `TlsConfig::{certs,key}_reader()`, `MtlsConfig::ca_certs_reader()` methods return `BufReader`s, which allow reading the configured certs and key directly. * A new `NamedFile::open_with()` constructor allows specifying `OpenOptions`. These improvements resulted in the following breaking changes: * The MSRV is now 1.74. * `hyper` is no longer exported from `rocket::http`. * `IoHandler::io` takes `Box<Self>` instead of `Pin<Box<Self>>`. - Use `Box::into_pin(self)` to recover the previous type. * `Response::upgrade()` now returns an `&mut dyn IoHandler`, not `Pin<& mut _>`. * `Config::{address,port,tls,mtls}` methods have been removed. - Use methods on `Rocket::endpoint()` instead. * `TlsConfig` was moved to `tls::TlsConfig`. * `MutualTls` was renamed and moved to `mtls::MtlsConfig`. * `ErrorKind::TlsBind` was removed. * The second field of `ErrorKind::Shutdown` was removed. * `{Local}Request::{set_}remote()` methods take/return an `Endpoint`. * `Client::new()` was removed; it was previously deprecated. Internally, the following major changes were made: * A new `async_bound` attribute macro was introduced to allow setting bounds on futures returned by `async fn`s in traits while maintaining good docs. * All utility functionality was moved to a new `util` module. Resolves #2671. Resolves #1070.
This commit is contained in:
parent
e9b568d9b2
commit
fd294049c7
|
@ -33,7 +33,6 @@ use crate::result::{Result, Error};
|
|||
///
|
||||
/// [`StreamExt`]: rocket::futures::StreamExt
|
||||
/// [`SinkExt`]: rocket::futures::SinkExt
|
||||
|
||||
pub struct DuplexStream(tokio_tungstenite::WebSocketStream<IoStream>);
|
||||
|
||||
impl DuplexStream {
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use std::io;
|
||||
use std::pin::Pin;
|
||||
|
||||
use rocket::data::{IoHandler, IoStream};
|
||||
use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
|
||||
|
@ -37,10 +36,6 @@ pub struct WebSocket {
|
|||
}
|
||||
|
||||
impl WebSocket {
|
||||
fn new(key: String) -> WebSocket {
|
||||
WebSocket { config: Config::default(), key }
|
||||
}
|
||||
|
||||
/// Change the default connection configuration to `config`.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -202,7 +197,9 @@ impl<'r> FromRequest<'r> for WebSocket {
|
|||
let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
|
||||
let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
|
||||
match key {
|
||||
Some(key) if is_upgrade && is_ws && is_13 => Outcome::Success(WebSocket::new(key)),
|
||||
Some(key) if is_upgrade && is_ws && is_13 => {
|
||||
Outcome::Success(WebSocket { key, config: Config::default() })
|
||||
},
|
||||
Some(_) | None => Outcome::Forward(Status::BadRequest)
|
||||
}
|
||||
}
|
||||
|
@ -232,9 +229,9 @@ impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
|
|||
|
||||
#[rocket::async_trait]
|
||||
impl IoHandler for Channel<'_> {
|
||||
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
||||
let channel = Pin::into_inner(self);
|
||||
let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await;
|
||||
async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||
let stream = DuplexStream::new(io, self.ws.config).await;
|
||||
let result = (self.handler)(stream).await;
|
||||
handle_result(result).map(|_| ())
|
||||
}
|
||||
}
|
||||
|
@ -243,9 +240,9 @@ impl IoHandler for Channel<'_> {
|
|||
impl<'r, S> IoHandler for MessageStream<'r, S>
|
||||
where S: futures::Stream<Item = Result<Message>> + Send + 'r
|
||||
{
|
||||
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
||||
async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||
let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
|
||||
let stream = (Pin::into_inner(self).handler)(source);
|
||||
let stream = (self.handler)(source);
|
||||
rocket::tokio::pin!(stream);
|
||||
while let Some(msg) = stream.next().await {
|
||||
let result = match msg {
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
use proc_macro2::{TokenStream, Span};
|
||||
use devise::{Spanned, Result, ext::SpanDiagnosticExt};
|
||||
use syn::{Token, parse_quote, parse_quote_spanned};
|
||||
use syn::{TraitItemFn, TypeParamBound, ReturnType, Attribute};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::parse::Parser;
|
||||
|
||||
fn _async_bound(
|
||||
args: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream
|
||||
) -> Result<TokenStream> {
|
||||
let bounds = <Punctuated<TypeParamBound, Token![+]>>::parse_terminated.parse(args)?;
|
||||
if bounds.is_empty() {
|
||||
return Ok(input.into());
|
||||
}
|
||||
|
||||
let mut func: TraitItemFn = syn::parse(input)?;
|
||||
let original: TraitItemFn = func.clone();
|
||||
if !func.sig.asyncness.is_some() {
|
||||
let diag = Span::call_site()
|
||||
.error("attribute can only be applied to async fns")
|
||||
.span_help(func.sig.span(), "this fn declaration must be `async`");
|
||||
|
||||
return Err(diag);
|
||||
}
|
||||
|
||||
let doc: Attribute = parse_quote! {
|
||||
#[doc = concat!(
|
||||
"# Future Bounds",
|
||||
"\n",
|
||||
"**The `Future` generated by this `async fn` must be `", stringify!(#bounds), "`**."
|
||||
)]
|
||||
};
|
||||
|
||||
func.sig.asyncness = None;
|
||||
func.sig.output = match func.sig.output {
|
||||
ReturnType::Type(arrow, ty) => parse_quote_spanned!(ty.span() =>
|
||||
#arrow impl ::core::future::Future<Output = #ty> + #bounds
|
||||
),
|
||||
default@ReturnType::Default => default
|
||||
};
|
||||
|
||||
Ok(quote! {
|
||||
#[cfg(all(not(doc), rust_analyzer))]
|
||||
#original
|
||||
|
||||
#[cfg(all(doc, not(rust_analyzer)))]
|
||||
#doc
|
||||
#original
|
||||
|
||||
#[cfg(not(any(doc, rust_analyzer)))]
|
||||
#func
|
||||
})
|
||||
}
|
||||
|
||||
pub fn async_bound(
|
||||
args: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream
|
||||
) -> TokenStream {
|
||||
_async_bound(args, input).unwrap_or_else(|d| d.emit_as_item_tokens())
|
||||
}
|
|
@ -2,3 +2,4 @@ pub mod entry;
|
|||
pub mod catch;
|
||||
pub mod route;
|
||||
pub mod param;
|
||||
pub mod async_bound;
|
||||
|
|
|
@ -331,7 +331,7 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
|
|||
let internal_uri_macro = internal_uri_macro_decl(&route);
|
||||
let responder_outcome = responder_outcome_expr(&route);
|
||||
|
||||
let method = route.attr.method;
|
||||
let method = &route.attr.method;
|
||||
let uri = route.attr.uri.to_string();
|
||||
let rank = Optional(route.attr.rank);
|
||||
let format = Optional(route.attr.format.as_ref());
|
||||
|
|
|
@ -13,7 +13,7 @@ pub struct Status(pub http::Status);
|
|||
#[derive(Debug)]
|
||||
pub struct MediaType(pub http::MediaType);
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Method(pub http::Method);
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -108,7 +108,7 @@ const VALID_METHODS: &[http::Method] = &[
|
|||
impl FromMeta for Method {
|
||||
fn from_meta(meta: &MetaItem) -> Result<Self> {
|
||||
let span = meta.value_span();
|
||||
let help_text = format!("method must be one of: {}", VALID_METHODS_STR);
|
||||
let help_text = format!("method must be one of: {VALID_METHODS_STR}");
|
||||
|
||||
if let MetaItem::Path(path) = meta {
|
||||
if let Some(ident) = path.last_ident() {
|
||||
|
@ -131,19 +131,13 @@ impl FromMeta for Method {
|
|||
|
||||
impl ToTokens for Method {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
let method_tokens = match self.0 {
|
||||
http::Method::Get => quote!(::rocket::http::Method::Get),
|
||||
http::Method::Put => quote!(::rocket::http::Method::Put),
|
||||
http::Method::Post => quote!(::rocket::http::Method::Post),
|
||||
http::Method::Delete => quote!(::rocket::http::Method::Delete),
|
||||
http::Method::Options => quote!(::rocket::http::Method::Options),
|
||||
http::Method::Head => quote!(::rocket::http::Method::Head),
|
||||
http::Method::Trace => quote!(::rocket::http::Method::Trace),
|
||||
http::Method::Connect => quote!(::rocket::http::Method::Connect),
|
||||
http::Method::Patch => quote!(::rocket::http::Method::Patch),
|
||||
};
|
||||
let mut chars = self.0.as_str().chars();
|
||||
let variant_str = chars.next()
|
||||
.map(|c| c.to_ascii_uppercase().to_string() + &chars.as_str().to_lowercase())
|
||||
.unwrap_or_default();
|
||||
|
||||
tokens.extend(method_tokens);
|
||||
let variant = syn::Ident::new(&variant_str, Span::call_site());
|
||||
tokens.extend(quote!(::rocket::http::Method::#variant));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1497,3 +1497,10 @@ pub fn internal_guide_tests(input: TokenStream) -> TokenStream {
|
|||
pub fn export(input: TokenStream) -> TokenStream {
|
||||
emit!(bang::export_internal(input))
|
||||
}
|
||||
|
||||
/// Private Rocket attribute: `async_bound(Bounds + On + Returned + Future)`.
|
||||
#[doc(hidden)]
|
||||
#[proc_macro_attribute]
|
||||
pub fn async_bound(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
emit!(attribute::async_bound::async_bound(args, input))
|
||||
}
|
||||
|
|
|
@ -17,43 +17,22 @@ rust-version = "1.64"
|
|||
|
||||
[features]
|
||||
default = []
|
||||
tls = ["rustls", "tokio-rustls", "rustls-pemfile"]
|
||||
mtls = ["tls", "x509-parser"]
|
||||
http2 = ["hyper/http2"]
|
||||
private-cookies = ["cookie/private", "cookie/key-expansion"]
|
||||
serde = ["uncased/with-serde-alloc", "serde_"]
|
||||
uuid = ["uuid_"]
|
||||
|
||||
[dependencies]
|
||||
smallvec = { version = "1.11", features = ["const_generics", "const_new"] }
|
||||
percent-encoding = "2"
|
||||
http = "0.2"
|
||||
time = { version = "0.3", features = ["formatting", "macros"] }
|
||||
indexmap = "2"
|
||||
rustls = { version = "0.22", optional = true }
|
||||
tokio-rustls = { version = "0.25", optional = true }
|
||||
rustls-pemfile = { version = "2.0.0", optional = true }
|
||||
tokio = { version = "1.6.1", features = ["net", "sync", "time"] }
|
||||
log = "0.4"
|
||||
ref-cast = "1.0"
|
||||
uncased = "0.9.6"
|
||||
uncased = "0.9.10"
|
||||
either = "1"
|
||||
pear = "0.2.8"
|
||||
pin-project-lite = "0.2"
|
||||
memchr = "2"
|
||||
stable-pattern = "0.1"
|
||||
cookie = { version = "0.18", features = ["percent-encode"] }
|
||||
state = "0.6"
|
||||
futures = { version = "0.3", default-features = false }
|
||||
|
||||
[dependencies.x509-parser]
|
||||
version = "0.13"
|
||||
optional = true
|
||||
|
||||
[dependencies.hyper]
|
||||
version = "0.14.9"
|
||||
default-features = false
|
||||
features = ["http1", "runtime", "server", "stream"]
|
||||
|
||||
[dependencies.serde_]
|
||||
package = "serde"
|
||||
|
|
|
@ -745,8 +745,7 @@ impl<'h> HeaderMap<'h> {
|
|||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
||||
#[doc(hidden)]
|
||||
#[inline]
|
||||
pub fn into_iter_raw(self)
|
||||
-> impl Iterator<Item=(Uncased<'h>, Vec<Cow<'h, str>>)> {
|
||||
pub fn into_iter_raw(self) -> impl Iterator<Item=(Uncased<'h>, Vec<Cow<'h, str>>)> {
|
||||
self.headers.into_iter()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
//! Re-exported hyper HTTP library types.
|
||||
//!
|
||||
//! All types that are re-exported from Hyper reside inside of this module.
|
||||
//! These types will, with certainty, be removed with time, but they reside here
|
||||
//! while necessary.
|
||||
|
||||
pub use hyper::{Method, Error, Body, Uri, Version, Request, Response};
|
||||
pub use hyper::{body, server, service, upgrade};
|
||||
pub use http::{HeaderValue, request, uri};
|
||||
|
||||
/// Reexported Hyper HTTP header types.
|
||||
pub mod header {
|
||||
macro_rules! import_http_headers {
|
||||
($($name:ident),*) => ($(
|
||||
pub use hyper::header::$name as $name;
|
||||
)*)
|
||||
}
|
||||
|
||||
import_http_headers! {
|
||||
ACCEPT, ACCEPT_CHARSET, ACCEPT_ENCODING, ACCEPT_LANGUAGE, ACCEPT_RANGES,
|
||||
ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS,
|
||||
ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE,
|
||||
ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, ALLOW,
|
||||
AUTHORIZATION, CACHE_CONTROL, CONNECTION, CONTENT_DISPOSITION,
|
||||
CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_LOCATION,
|
||||
CONTENT_RANGE, CONTENT_SECURITY_POLICY,
|
||||
CONTENT_SECURITY_POLICY_REPORT_ONLY, CONTENT_TYPE, DATE, ETAG, EXPECT,
|
||||
EXPIRES, FORWARDED, FROM, HOST, IF_MATCH, IF_MODIFIED_SINCE,
|
||||
IF_NONE_MATCH, IF_RANGE, IF_UNMODIFIED_SINCE, LAST_MODIFIED, LINK,
|
||||
LOCATION, ORIGIN, PRAGMA, RANGE, REFERER, REFERRER_POLICY, REFRESH,
|
||||
STRICT_TRANSPORT_SECURITY, TE, TRANSFER_ENCODING, UPGRADE, USER_AGENT,
|
||||
VARY
|
||||
}
|
||||
}
|
|
@ -4,15 +4,11 @@
|
|||
//! Types that map to concepts in HTTP.
|
||||
//!
|
||||
//! This module exports types that map to HTTP concepts or to the underlying
|
||||
//! HTTP library when needed. Because the underlying HTTP library is likely to
|
||||
//! change (see [#17]), types in [`hyper`] should be considered unstable.
|
||||
//!
|
||||
//! [#17]: https://github.com/rwf2/Rocket/issues/17
|
||||
//! HTTP library when needed.
|
||||
|
||||
#[macro_use]
|
||||
extern crate pear;
|
||||
|
||||
pub mod hyper;
|
||||
pub mod uri;
|
||||
pub mod ext;
|
||||
|
||||
|
@ -22,7 +18,6 @@ mod method;
|
|||
mod status;
|
||||
mod raw_str;
|
||||
mod parse;
|
||||
mod listener;
|
||||
|
||||
/// Case-preserving, ASCII case-insensitive string types.
|
||||
///
|
||||
|
@ -39,14 +34,8 @@ pub mod uncased {
|
|||
pub mod private {
|
||||
pub use crate::parse::Indexed;
|
||||
pub use smallvec::{SmallVec, Array};
|
||||
pub use crate::listener::{TcpListener, Incoming, Listener, Connection, Certificates};
|
||||
pub use cookie;
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[cfg(feature = "tls")]
|
||||
pub mod tls;
|
||||
|
||||
pub use crate::method::Method;
|
||||
pub use crate::status::{Status, StatusClass};
|
||||
pub use crate::raw_str::{RawStr, RawStrBuf};
|
||||
|
|
|
@ -1,257 +0,0 @@
|
|||
use std::fmt;
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use std::sync::Arc;
|
||||
|
||||
use log::warn;
|
||||
use tokio::time::Sleep;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use hyper::server::accept::Accept;
|
||||
use state::InitCell;
|
||||
|
||||
pub use tokio::net::TcpListener;
|
||||
|
||||
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
|
||||
#[cfg(not(feature = "tls"))]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct CertificateDer(pub(crate) Vec<u8>);
|
||||
|
||||
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
|
||||
#[cfg(feature = "tls")]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
#[repr(transparent)]
|
||||
pub struct CertificateDer(pub(crate) rustls::pki_types::CertificateDer<'static>);
|
||||
|
||||
/// A collection of raw certificate data.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct Certificates(Arc<InitCell<Vec<CertificateDer>>>);
|
||||
|
||||
impl From<Vec<CertificateDer>> for Certificates {
|
||||
fn from(value: Vec<CertificateDer>) -> Self {
|
||||
Certificates(Arc::new(value.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
impl From<Vec<rustls::pki_types::CertificateDer<'static>>> for Certificates {
|
||||
fn from(value: Vec<rustls::pki_types::CertificateDer<'static>>) -> Self {
|
||||
let value: Vec<_> = value.into_iter().map(CertificateDer).collect();
|
||||
Certificates(Arc::new(value.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
impl Certificates {
|
||||
/// Set the the raw certificate chain data. Only the first call actually
|
||||
/// sets the data; the remaining do nothing.
|
||||
#[cfg(feature = "tls")]
|
||||
pub(crate) fn set(&self, data: Vec<CertificateDer>) {
|
||||
self.0.set(data);
|
||||
}
|
||||
|
||||
/// Returns the raw certificate chain data, if any is available.
|
||||
pub fn chain_data(&self) -> Option<&[CertificateDer]> {
|
||||
self.0.try_get().map(|v| v.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
// TODO.async: 'Listener' and 'Connection' provide common enough functionality
|
||||
// that they could be introduced in upstream libraries.
|
||||
/// A 'Listener' yields incoming connections
|
||||
pub trait Listener {
|
||||
/// The connection type returned by this listener.
|
||||
type Connection: Connection;
|
||||
|
||||
/// Return the actual address this listener bound to.
|
||||
fn local_addr(&self) -> Option<SocketAddr>;
|
||||
|
||||
/// Try to accept an incoming Connection if ready. This should only return
|
||||
/// an `Err` when a fatal problem occurs as Hyper kills the server on `Err`.
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<Self::Connection>>;
|
||||
}
|
||||
|
||||
/// A 'Connection' represents an open connection to a client
|
||||
pub trait Connection: AsyncRead + AsyncWrite {
|
||||
/// The remote address, i.e. the client's socket address, if it is known.
|
||||
fn peer_address(&self) -> Option<SocketAddr>;
|
||||
|
||||
/// Requests that the connection not delay reading or writing data as much
|
||||
/// as possible. For connections backed by TCP, this corresponds to setting
|
||||
/// `TCP_NODELAY`.
|
||||
fn enable_nodelay(&self) -> io::Result<()>;
|
||||
|
||||
/// DER-encoded X.509 certificate chain presented by the client, if any.
|
||||
///
|
||||
/// The certificate order must be as it appears in the TLS protocol: the
|
||||
/// first certificate relates to the peer, the second certifies the first,
|
||||
/// the third certifies the second, and so on.
|
||||
///
|
||||
/// Defaults to an empty vector to indicate that no certificates were
|
||||
/// presented.
|
||||
fn peer_certificates(&self) -> Option<Certificates> { None }
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
/// This is a generic version of hyper's AddrIncoming that is intended to be
|
||||
/// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix
|
||||
/// sockets. It does so by bridging the `Listener` trait to what hyper wants (an
|
||||
/// Accept). This type is internal to Rocket.
|
||||
#[must_use = "streams do nothing unless polled"]
|
||||
pub struct Incoming<L> {
|
||||
sleep_on_errors: Option<Duration>,
|
||||
nodelay: bool,
|
||||
#[pin]
|
||||
pending_error_delay: Option<Sleep>,
|
||||
#[pin]
|
||||
listener: L,
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Listener> Incoming<L> {
|
||||
/// Construct an `Incoming` from an existing `Listener`.
|
||||
pub fn new(listener: L) -> Self {
|
||||
Self {
|
||||
listener,
|
||||
sleep_on_errors: Some(Duration::from_millis(250)),
|
||||
pending_error_delay: None,
|
||||
nodelay: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set whether and how long to sleep on accept errors.
|
||||
///
|
||||
/// A possible scenario is that the process has hit the max open files
|
||||
/// allowed, and so trying to accept a new connection will fail with
|
||||
/// `EMFILE`. In some cases, it's preferable to just wait for some time, if
|
||||
/// the application will likely close some files (or connections), and try
|
||||
/// to accept the connection again. If this option is `true`, the error
|
||||
/// will be logged at the `error` level, since it is still a big deal,
|
||||
/// and then the listener will sleep for 1 second.
|
||||
///
|
||||
/// In other cases, hitting the max open files should be treat similarly
|
||||
/// to being out-of-memory, and simply error (and shutdown). Setting
|
||||
/// this option to `None` will allow that.
|
||||
///
|
||||
/// Default is 1 second.
|
||||
pub fn sleep_on_errors(mut self, val: Option<Duration>) -> Self {
|
||||
self.sleep_on_errors = val;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set whether to request no delay on all incoming connections. The default
|
||||
/// is `false`. See [`Connection::enable_nodelay()`] for details.
|
||||
pub fn nodelay(mut self, nodelay: bool) -> Self {
|
||||
self.nodelay = nodelay;
|
||||
self
|
||||
}
|
||||
|
||||
fn poll_accept_next(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<L::Connection>> {
|
||||
/// This function defines per-connection errors: errors that affect only
|
||||
/// a single connection's accept() and don't imply anything about the
|
||||
/// success probability of the next accept(). Thus, we can attempt to
|
||||
/// `accept()` another connection immediately. All other errors will
|
||||
/// incur a delay before the next `accept()` is performed. The delay is
|
||||
/// useful to handle resource exhaustion errors like ENFILE and EMFILE.
|
||||
/// Otherwise, could enter into tight loop.
|
||||
fn is_connection_error(e: &io::Error) -> bool {
|
||||
matches!(e.kind(),
|
||||
| io::ErrorKind::ConnectionRefused
|
||||
| io::ErrorKind::ConnectionAborted
|
||||
| io::ErrorKind::ConnectionReset)
|
||||
}
|
||||
|
||||
let mut this = self.project();
|
||||
loop {
|
||||
// Check if a previous sleep timer is active, set on I/O errors.
|
||||
if let Some(delay) = this.pending_error_delay.as_mut().as_pin_mut() {
|
||||
futures::ready!(delay.poll(cx));
|
||||
}
|
||||
|
||||
this.pending_error_delay.set(None);
|
||||
|
||||
match futures::ready!(this.listener.as_mut().poll_accept(cx)) {
|
||||
Ok(stream) => {
|
||||
if *this.nodelay {
|
||||
if let Err(e) = stream.enable_nodelay() {
|
||||
warn!("failed to enable NODELAY: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
return Poll::Ready(Ok(stream));
|
||||
},
|
||||
Err(e) => {
|
||||
if is_connection_error(&e) {
|
||||
warn!("single connection accept error {}; accepting next now", e);
|
||||
} else if let Some(duration) = this.sleep_on_errors {
|
||||
// We might be able to recover. Try again in a bit.
|
||||
warn!("accept error {}; recovery attempt in {}ms", e, duration.as_millis());
|
||||
this.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
|
||||
} else {
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Listener> Accept for Incoming<L> {
|
||||
type Conn = L::Connection;
|
||||
type Error = io::Error;
|
||||
|
||||
#[inline]
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<Option<io::Result<Self::Conn>>> {
|
||||
self.poll_accept_next(cx).map(Some)
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Incoming")
|
||||
.field("listener", &self.listener)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Listener for TcpListener {
|
||||
type Connection = TcpStream;
|
||||
|
||||
#[inline]
|
||||
fn local_addr(&self) -> Option<SocketAddr> {
|
||||
self.local_addr().ok()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<Self::Connection>> {
|
||||
(*self).poll_accept(cx).map_ok(|(stream, _addr)| stream)
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for TcpStream {
|
||||
#[inline]
|
||||
fn peer_address(&self) -> Option<SocketAddr> {
|
||||
self.peer_addr().ok()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn enable_nodelay(&self) -> io::Result<()> {
|
||||
self.set_nodelay(true)
|
||||
}
|
||||
}
|
|
@ -3,8 +3,6 @@ use std::str::FromStr;
|
|||
|
||||
use self::Method::*;
|
||||
|
||||
use crate::hyper;
|
||||
|
||||
// TODO: Support non-standard methods, here and in codegen?
|
||||
|
||||
/// Representation of HTTP methods.
|
||||
|
@ -29,6 +27,7 @@ use crate::hyper;
|
|||
/// }
|
||||
/// # }
|
||||
/// ```
|
||||
#[repr(u8)]
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum Method {
|
||||
/// The `GET` variant.
|
||||
|
@ -52,23 +51,6 @@ pub enum Method {
|
|||
}
|
||||
|
||||
impl Method {
|
||||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
||||
#[doc(hidden)]
|
||||
pub fn from_hyp(method: &hyper::Method) -> Option<Method> {
|
||||
match *method {
|
||||
hyper::Method::GET => Some(Get),
|
||||
hyper::Method::PUT => Some(Put),
|
||||
hyper::Method::POST => Some(Post),
|
||||
hyper::Method::DELETE => Some(Delete),
|
||||
hyper::Method::OPTIONS => Some(Options),
|
||||
hyper::Method::HEAD => Some(Head),
|
||||
hyper::Method::TRACE => Some(Trace),
|
||||
hyper::Method::CONNECT => Some(Connect),
|
||||
hyper::Method::PATCH => Some(Patch),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if an HTTP request with the method represented by `self`
|
||||
/// always supports a payload.
|
||||
///
|
||||
|
|
|
@ -1,235 +0,0 @@
|
|||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::future::Future;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream};
|
||||
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
|
||||
|
||||
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
|
||||
use crate::listener::{Connection, Listener, Certificates, CertificateDer};
|
||||
|
||||
/// A TLS listener over TCP.
|
||||
pub struct TlsListener {
|
||||
listener: TcpListener,
|
||||
acceptor: TlsAcceptor,
|
||||
}
|
||||
|
||||
/// This implementation exists so that ROCKET_WORKERS=1 can make progress while
|
||||
/// a TLS handshake is being completed. It does this by returning `Ready` from
|
||||
/// `poll_accept()` as soon as we have a TCP connection and performing the
|
||||
/// handshake in the `AsyncRead` and `AsyncWrite` implementations.
|
||||
///
|
||||
/// A straight-forward implementation of this strategy results in none of the
|
||||
/// TLS information being available at the time the connection is "established",
|
||||
/// that is, when `poll_accept()` returns, since the handshake has yet to occur.
|
||||
/// Importantly, certificate information isn't available at the time that we
|
||||
/// request it.
|
||||
///
|
||||
/// The underlying problem is hyper's "Accept" trait. Were we to manage
|
||||
/// connections ourselves, we'd likely want to:
|
||||
///
|
||||
/// 1. Stop blocking the worker as soon as we have a TCP connection.
|
||||
/// 2. Perform the handshake in the background.
|
||||
/// 3. Give the connection to Rocket when/if the handshake is done.
|
||||
///
|
||||
/// See hyperium/hyper/issues/2321 for more details.
|
||||
///
|
||||
/// To work around this, we "lie" when `peer_certificates()` are requested and
|
||||
/// always return `Some(Certificates)`. Internally, `Certificates` is an
|
||||
/// `Arc<InitCell<Vec<CertificateDer>>>`, effectively a shared, thread-safe,
|
||||
/// `OnceCell`. The cell is initially empty and is filled as soon as the
|
||||
/// handshake is complete. If the certificate data were to be requested prior to
|
||||
/// this point, it would be empty. However, in Rocket, we only request
|
||||
/// certificate data when we have a `Request` object, which implies we're
|
||||
/// receiving payload data, which implies the TLS handshake has finished, so the
|
||||
/// certificate data as seen by a Rocket application will always be "fresh".
|
||||
pub struct TlsStream {
|
||||
remote: SocketAddr,
|
||||
state: TlsState,
|
||||
certs: Certificates,
|
||||
}
|
||||
|
||||
/// State of `TlsStream`.
|
||||
pub enum TlsState {
|
||||
/// The TLS handshake is taking place. We don't have a full connection yet.
|
||||
Handshaking(Accept<TcpStream>),
|
||||
/// TLS handshake completed successfully; we're getting payload data.
|
||||
Streaming(BareTlsStream<TcpStream>),
|
||||
}
|
||||
|
||||
/// TLS as ~configured by `TlsConfig` in `rocket` core.
|
||||
pub struct Config<R> {
|
||||
pub cert_chain: R,
|
||||
pub private_key: R,
|
||||
pub ciphersuites: Vec<rustls::SupportedCipherSuite>,
|
||||
pub prefer_server_order: bool,
|
||||
pub ca_certs: Option<R>,
|
||||
pub mandatory_mtls: bool,
|
||||
}
|
||||
|
||||
impl TlsListener {
|
||||
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> crate::tls::Result<TlsListener>
|
||||
where R: io::BufRead
|
||||
{
|
||||
let provider = rustls::crypto::CryptoProvider {
|
||||
cipher_suites: c.ciphersuites,
|
||||
..rustls::crypto::ring::default_provider()
|
||||
};
|
||||
|
||||
let verifier = match c.ca_certs {
|
||||
Some(ref mut ca_certs) => {
|
||||
let ca_roots = Arc::new(load_ca_certs(ca_certs)?);
|
||||
let verifier = WebPkiClientVerifier::builder(ca_roots);
|
||||
match c.mandatory_mtls {
|
||||
true => verifier.build()?,
|
||||
false => verifier.allow_unauthenticated().build()?,
|
||||
}
|
||||
},
|
||||
None => WebPkiClientVerifier::no_client_auth(),
|
||||
};
|
||||
|
||||
let key = load_key(&mut c.private_key)?;
|
||||
let cert_chain = load_cert_chain(&mut c.cert_chain)?;
|
||||
let mut config = ServerConfig::builder_with_provider(Arc::new(provider))
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_client_cert_verifier(verifier)
|
||||
.with_single_cert(cert_chain, key)?;
|
||||
|
||||
config.ignore_client_order = c.prefer_server_order;
|
||||
config.session_storage = ServerSessionMemoryCache::new(1024);
|
||||
config.ticketer = rustls::crypto::ring::Ticketer::new()?;
|
||||
config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
if cfg!(feature = "http2") {
|
||||
config.alpn_protocols.insert(0, b"h2".to_vec());
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let acceptor = TlsAcceptor::from(Arc::new(config));
|
||||
Ok(TlsListener { listener, acceptor })
|
||||
}
|
||||
}
|
||||
|
||||
impl Listener for TlsListener {
|
||||
type Connection = TlsStream;
|
||||
|
||||
fn local_addr(&self) -> Option<SocketAddr> {
|
||||
self.listener.local_addr().ok()
|
||||
}
|
||||
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<Self::Connection>> {
|
||||
match futures::ready!(self.listener.poll_accept(cx)) {
|
||||
Ok((io, addr)) => Poll::Ready(Ok(TlsStream {
|
||||
remote: addr,
|
||||
state: TlsState::Handshaking(self.acceptor.accept(io)),
|
||||
// These are empty and filled in after handshake is complete.
|
||||
certs: Certificates::default(),
|
||||
})),
|
||||
Err(e) => Poll::Ready(Err(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for TlsStream {
|
||||
fn peer_address(&self) -> Option<SocketAddr> {
|
||||
Some(self.remote)
|
||||
}
|
||||
|
||||
fn enable_nodelay(&self) -> io::Result<()> {
|
||||
// If `Handshaking` is `None`, it either failed, so we returned an `Err`
|
||||
// from `poll_accept()` and there's no connection to enable `NODELAY`
|
||||
// on, or it succeeded, so we're in the `Streaming` stage and we have
|
||||
// infallible access to the connection.
|
||||
match &self.state {
|
||||
TlsState::Handshaking(accept) => match accept.get_ref() {
|
||||
None => Ok(()),
|
||||
Some(s) => s.enable_nodelay(),
|
||||
},
|
||||
TlsState::Streaming(stream) => stream.get_ref().0.enable_nodelay()
|
||||
}
|
||||
}
|
||||
|
||||
fn peer_certificates(&self) -> Option<Certificates> {
|
||||
Some(self.certs.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl TlsStream {
|
||||
fn poll_accept_then<F, T>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
mut f: F
|
||||
) -> Poll<io::Result<T>>
|
||||
where F: FnMut(&mut BareTlsStream<TcpStream>, &mut Context<'_>) -> Poll<io::Result<T>>
|
||||
{
|
||||
loop {
|
||||
match self.state {
|
||||
TlsState::Handshaking(ref mut accept) => {
|
||||
match futures::ready!(Pin::new(accept).poll(cx)) {
|
||||
Ok(stream) => {
|
||||
if let Some(peer_certs) = stream.get_ref().1.peer_certificates() {
|
||||
self.certs.set(peer_certs.into_iter()
|
||||
.map(|v| CertificateDer(v.clone().into_owned()))
|
||||
.collect());
|
||||
}
|
||||
|
||||
self.state = TlsState::Streaming(stream);
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("tls handshake with {} failed: {}", self.remote, e);
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
}
|
||||
},
|
||||
TlsState::Streaming(ref mut stream) => return f(stream, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TlsStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.poll_accept_then(cx, |stream, cx| Pin::new(stream).poll_read(cx, buf))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TlsStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.poll_accept_then(cx, |stream, cx| Pin::new(stream).poll_write(cx, buf))
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
match &mut self.state {
|
||||
TlsState::Handshaking(accept) => match accept.get_mut() {
|
||||
Some(io) => Pin::new(io).poll_flush(cx),
|
||||
None => Poll::Ready(Ok(())),
|
||||
}
|
||||
TlsState::Streaming(stream) => Pin::new(stream).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
match &mut self.state {
|
||||
TlsState::Handshaking(accept) => match accept.get_mut() {
|
||||
Some(io) => Pin::new(io).poll_shutdown(cx),
|
||||
None => Poll::Ready(Ok(())),
|
||||
}
|
||||
TlsState::Streaming(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
mod listener;
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
pub mod mtls;
|
||||
|
||||
pub use rustls;
|
||||
pub use listener::{TlsListener, Config};
|
||||
pub mod util;
|
||||
pub mod error;
|
||||
|
||||
pub use error::Result;
|
|
@ -20,23 +20,36 @@ rust-version = "1.64"
|
|||
all-features = true
|
||||
|
||||
[features]
|
||||
default = ["http2"]
|
||||
tls = ["rocket_http/tls"]
|
||||
mtls = ["rocket_http/mtls", "tls"]
|
||||
http2 = ["rocket_http/http2"]
|
||||
secrets = ["rocket_http/private-cookies"]
|
||||
json = ["serde_json", "tokio/io-util"]
|
||||
msgpack = ["rmp-serde", "tokio/io-util"]
|
||||
default = ["http2", "tokio-macros"]
|
||||
http2 = ["hyper/http2", "hyper-util/http2"]
|
||||
secrets = ["cookie/private", "cookie/key-expansion"]
|
||||
json = ["serde_json"]
|
||||
msgpack = ["rmp-serde"]
|
||||
uuid = ["uuid_", "rocket_http/uuid"]
|
||||
tls = ["rustls", "tokio-rustls", "rustls-pemfile"]
|
||||
mtls = ["tls", "x509-parser"]
|
||||
tokio-macros = ["tokio/macros"]
|
||||
|
||||
[dependencies]
|
||||
# Serialization dependencies.
|
||||
# Optional serialization dependencies.
|
||||
serde_json = { version = "1.0.26", optional = true }
|
||||
rmp-serde = { version = "1", optional = true }
|
||||
uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] }
|
||||
|
||||
# Optional TLS dependencies
|
||||
rustls = { version = "0.22", optional = true }
|
||||
tokio-rustls = { version = "0.25", optional = true }
|
||||
rustls-pemfile = { version = "2.0.0", optional = true }
|
||||
|
||||
# Optional MTLS dependencies
|
||||
x509-parser = { version = "0.13", optional = true }
|
||||
|
||||
# Hyper dependencies
|
||||
http = "1"
|
||||
bytes = "1.4"
|
||||
hyper = { version = "1.1", default-features = false, features = ["http1", "server"] }
|
||||
|
||||
# Non-optional, core dependencies from here on out.
|
||||
futures = { version = "0.3.0", default-features = false, features = ["std"] }
|
||||
yansi = { version = "1.0.0-rc", features = ["detect-tty"] }
|
||||
log = { version = "0.4", features = ["std"] }
|
||||
num_cpus = "1.0"
|
||||
|
@ -44,11 +57,11 @@ time = { version = "0.3", features = ["macros", "parsing"] }
|
|||
memchr = "2" # TODO: Use pear instead.
|
||||
binascii = "0.1"
|
||||
ref-cast = "1.0"
|
||||
atomic = "0.5"
|
||||
ref-swap = "0.1.2"
|
||||
parking_lot = "0.12"
|
||||
ubyte = {version = "0.10.2", features = ["serde"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
figment = { version = "0.10.6", features = ["toml", "env"] }
|
||||
figment = { version = "0.10.13", features = ["toml", "env"] }
|
||||
rand = "0.8"
|
||||
either = "1"
|
||||
pin-project-lite = "0.2"
|
||||
|
@ -58,8 +71,25 @@ async-trait = "0.1.43"
|
|||
async-stream = "0.3.2"
|
||||
multer = { version = "3.0.0", features = ["tokio-io"] }
|
||||
tokio-stream = { version = "0.1.6", features = ["signal", "time"] }
|
||||
cookie = { version = "0.18", features = ["percent-encode"] }
|
||||
futures = { version = "0.3.30", default-features = false, features = ["std"] }
|
||||
state = "0.6"
|
||||
|
||||
[dependencies.hyper-util]
|
||||
git = "https://github.com/SergioBenitez/hyper-util.git"
|
||||
branch = "fix-readversion"
|
||||
default-features = false
|
||||
features = ["http1", "server", "tokio"]
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1.35.1"
|
||||
features = ["rt-multi-thread", "net", "io-util", "fs", "time", "sync", "signal", "parking_lot"]
|
||||
|
||||
[dependencies.tokio-util]
|
||||
version = "0.7"
|
||||
default-features = false
|
||||
features = ["io"]
|
||||
|
||||
[dependencies.rocket_codegen]
|
||||
version = "0.6.0-dev"
|
||||
path = "../codegen"
|
||||
|
@ -69,21 +99,13 @@ version = "0.6.0-dev"
|
|||
path = "../http"
|
||||
features = ["serde"]
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1.6.1"
|
||||
features = ["fs", "io-std", "io-util", "rt-multi-thread", "sync", "signal", "macros"]
|
||||
|
||||
[dependencies.tokio-util]
|
||||
version = "0.7"
|
||||
default-features = false
|
||||
features = ["io"]
|
||||
|
||||
[dependencies.bytes]
|
||||
version = "1.0"
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
libc = "0.2.149"
|
||||
|
||||
[build-dependencies]
|
||||
version_check = "0.9.1"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1", features = ["macros", "io-std"] }
|
||||
figment = { version = "0.10", features = ["test"] }
|
||||
pretty_assertions = "1"
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
use std::net::{IpAddr, Ipv4Addr};
|
||||
|
||||
use figment::{Figment, Profile, Provider, Metadata, error::Result};
|
||||
use figment::providers::{Serialized, Env, Toml, Format};
|
||||
use figment::value::{Map, Dict, magic::RelativePathBuf};
|
||||
|
@ -12,9 +10,6 @@ use crate::request::{self, Request, FromRequest};
|
|||
use crate::http::uncased::Uncased;
|
||||
use crate::data::Limits;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use crate::config::TlsConfig;
|
||||
|
||||
#[cfg(feature = "secrets")]
|
||||
use crate::config::SecretKey;
|
||||
|
||||
|
@ -66,10 +61,6 @@ pub struct Config {
|
|||
/// set to the extracting Figment's selected `Profile`._
|
||||
#[serde(skip)]
|
||||
pub profile: Profile,
|
||||
/// IP address to serve on. **(default: `127.0.0.1`)**
|
||||
pub address: IpAddr,
|
||||
/// Port to serve on. **(default: `8000`)**
|
||||
pub port: u16,
|
||||
/// Number of threads to use for executing futures. **(default: `num_cores`)**
|
||||
///
|
||||
/// _**Note:** Rocket only reads this value from sources in the [default
|
||||
|
@ -121,10 +112,6 @@ pub struct Config {
|
|||
pub temp_dir: RelativePathBuf,
|
||||
/// Keep-alive timeout in seconds; disabled when `0`. **(default: `5`)**
|
||||
pub keep_alive: u32,
|
||||
/// The TLS configuration, if any. **(default: `None`)**
|
||||
#[cfg(feature = "tls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
|
||||
pub tls: Option<TlsConfig>,
|
||||
/// The secret key for signing and encrypting. **(default: `0`)**
|
||||
///
|
||||
/// _**Note:** This field _always_ serializes as a 256-bit array of `0`s to
|
||||
|
@ -148,7 +135,6 @@ pub struct Config {
|
|||
/// use rocket::Config;
|
||||
///
|
||||
/// let config = Config {
|
||||
/// port: 1024,
|
||||
/// keep_alive: 10,
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
|
@ -204,8 +190,6 @@ impl Config {
|
|||
pub fn debug_default() -> Config {
|
||||
Config {
|
||||
profile: Self::DEBUG_PROFILE,
|
||||
address: Ipv4Addr::new(127, 0, 0, 1).into(),
|
||||
port: 8000,
|
||||
workers: num_cpus::get(),
|
||||
max_blocking: 512,
|
||||
ident: Ident::default(),
|
||||
|
@ -214,8 +198,6 @@ impl Config {
|
|||
limits: Limits::default(),
|
||||
temp_dir: std::env::temp_dir().into(),
|
||||
keep_alive: 5,
|
||||
#[cfg(feature = "tls")]
|
||||
tls: None,
|
||||
#[cfg(feature = "secrets")]
|
||||
secret_key: SecretKey::zero(),
|
||||
shutdown: Shutdown::default(),
|
||||
|
@ -331,59 +313,6 @@ impl Config {
|
|||
Self::try_from(provider).unwrap_or_else(bail_with_config_error)
|
||||
}
|
||||
|
||||
/// Returns `true` if TLS is enabled.
|
||||
///
|
||||
/// TLS is enabled when the `tls` feature is enabled and TLS has been
|
||||
/// configured with at least one ciphersuite. Note that without changing
|
||||
/// defaults, all supported ciphersuites are enabled in the recommended
|
||||
/// configuration.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// let config = rocket::Config::default();
|
||||
/// if config.tls_enabled() {
|
||||
/// println!("TLS is enabled!");
|
||||
/// } else {
|
||||
/// println!("TLS is disabled.");
|
||||
/// }
|
||||
/// ```
|
||||
pub fn tls_enabled(&self) -> bool {
|
||||
#[cfg(feature = "tls")] {
|
||||
self.tls.as_ref().map_or(false, |tls| !tls.ciphers.is_empty())
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))] { false }
|
||||
}
|
||||
|
||||
/// Returns `true` if mTLS is enabled.
|
||||
///
|
||||
/// mTLS is enabled when TLS is enabled ([`Config::tls_enabled()`]) _and_
|
||||
/// the `mtls` feature is enabled _and_ mTLS has been configured with a CA
|
||||
/// certificate chain.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// let config = rocket::Config::default();
|
||||
/// if config.mtls_enabled() {
|
||||
/// println!("mTLS is enabled!");
|
||||
/// } else {
|
||||
/// println!("mTLS is disabled.");
|
||||
/// }
|
||||
/// ```
|
||||
pub fn mtls_enabled(&self) -> bool {
|
||||
if !self.tls_enabled() {
|
||||
return false;
|
||||
}
|
||||
|
||||
#[cfg(feature = "mtls")] {
|
||||
self.tls.as_ref().map_or(false, |tls| tls.mutual.is_some())
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "mtls"))] { false }
|
||||
}
|
||||
|
||||
#[cfg(feature = "secrets")]
|
||||
pub(crate) fn known_secret_key_used(&self) -> bool {
|
||||
const KNOWN_SECRET_KEYS: &'static [&'static str] = &[
|
||||
|
@ -420,8 +349,6 @@ impl Config {
|
|||
|
||||
self.trace_print(figment);
|
||||
launch_meta!("{}Configured for {}.", "🔧 ".emoji(), self.profile.underline());
|
||||
launch_meta_!("address: {}", self.address.paint(VAL));
|
||||
launch_meta_!("port: {}", self.port.paint(VAL));
|
||||
launch_meta_!("workers: {}", self.workers.paint(VAL));
|
||||
launch_meta_!("max blocking threads: {}", self.max_blocking.paint(VAL));
|
||||
launch_meta_!("ident: {}", self.ident.paint(VAL));
|
||||
|
@ -445,12 +372,6 @@ impl Config {
|
|||
ka => launch_meta_!("keep-alive: {}{}", ka.paint(VAL), "s".paint(VAL)),
|
||||
}
|
||||
|
||||
match (self.tls_enabled(), self.mtls_enabled()) {
|
||||
(true, true) => launch_meta_!("tls: {}", "enabled w/mtls".paint(VAL)),
|
||||
(true, false) => launch_meta_!("tls: {} w/o mtls", "enabled".paint(VAL)),
|
||||
(false, _) => launch_meta_!("tls: {}", "disabled".paint(VAL)),
|
||||
}
|
||||
|
||||
launch_meta_!("shutdown: {}", self.shutdown.paint(VAL));
|
||||
launch_meta_!("log level: {}", self.log_level.paint(VAL));
|
||||
launch_meta_!("cli colors: {}", self.cli_colors.paint(VAL));
|
||||
|
@ -519,12 +440,6 @@ impl Config {
|
|||
/// This isn't `pub` because setting it directly does nothing.
|
||||
const PROFILE: &'static str = "profile";
|
||||
|
||||
/// The stringy parameter name for setting/extracting [`Config::address`].
|
||||
pub const ADDRESS: &'static str = "address";
|
||||
|
||||
/// The stringy parameter name for setting/extracting [`Config::port`].
|
||||
pub const PORT: &'static str = "port";
|
||||
|
||||
/// The stringy parameter name for setting/extracting [`Config::workers`].
|
||||
pub const WORKERS: &'static str = "workers";
|
||||
|
||||
|
@ -546,9 +461,6 @@ impl Config {
|
|||
/// The stringy parameter name for setting/extracting [`Config::limits`].
|
||||
pub const LIMITS: &'static str = "limits";
|
||||
|
||||
/// The stringy parameter name for setting/extracting [`Config::tls`].
|
||||
pub const TLS: &'static str = "tls";
|
||||
|
||||
/// The stringy parameter name for setting/extracting [`Config::secret_key`].
|
||||
pub const SECRET_KEY: &'static str = "secret_key";
|
||||
|
||||
|
@ -566,9 +478,10 @@ impl Config {
|
|||
|
||||
/// An array of all of the stringy parameter names.
|
||||
pub const PARAMETERS: &'static [&'static str] = &[
|
||||
Self::ADDRESS, Self::PORT, Self::WORKERS, Self::MAX_BLOCKING, Self::KEEP_ALIVE,
|
||||
Self::IDENT, Self::IP_HEADER, Self::PROXY_PROTO_HEADER, Self::LIMITS, Self::TLS,
|
||||
Self::SECRET_KEY, Self::TEMP_DIR, Self::LOG_LEVEL, Self::SHUTDOWN, Self::CLI_COLORS,
|
||||
Self::WORKERS, Self::MAX_BLOCKING, Self::KEEP_ALIVE, Self::IDENT,
|
||||
Self::IP_HEADER, Self::PROXY_PROTO_HEADER, Self::LIMITS,
|
||||
Self::SECRET_KEY, Self::TEMP_DIR, Self::LOG_LEVEL, Self::SHUTDOWN,
|
||||
Self::CLI_COLORS,
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
@ -117,9 +117,6 @@ mod shutdown;
|
|||
mod cli_colors;
|
||||
mod http_header;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
mod tls;
|
||||
|
||||
#[cfg(feature = "secrets")]
|
||||
mod secret_key;
|
||||
|
||||
|
@ -132,12 +129,6 @@ pub use shutdown::Shutdown;
|
|||
pub use ident::Ident;
|
||||
pub use cli_colors::CliColors;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub use tls::{TlsConfig, CipherSuite};
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
pub use tls::MutualTls;
|
||||
|
||||
#[cfg(feature = "secrets")]
|
||||
pub use secret_key::SecretKey;
|
||||
|
||||
|
@ -146,7 +137,6 @@ pub use shutdown::Sig;
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::Ipv4Addr;
|
||||
use figment::{Figment, Profile};
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
|
@ -202,9 +192,7 @@ mod tests {
|
|||
figment::Jail::expect_with(|jail| {
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[default]
|
||||
address = "1.2.3.4"
|
||||
ident = "Something Cool"
|
||||
port = 1234
|
||||
workers = 20
|
||||
keep_alive = 10
|
||||
log_level = "off"
|
||||
|
@ -213,8 +201,6 @@ mod tests {
|
|||
|
||||
let config = Config::from(Config::figment());
|
||||
assert_eq!(config, Config {
|
||||
address: Ipv4Addr::new(1, 2, 3, 4).into(),
|
||||
port: 1234,
|
||||
workers: 20,
|
||||
ident: ident!("Something Cool"),
|
||||
keep_alive: 10,
|
||||
|
@ -225,9 +211,7 @@ mod tests {
|
|||
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[global]
|
||||
address = "1.2.3.4"
|
||||
ident = "Something Else Cool"
|
||||
port = 1234
|
||||
workers = 20
|
||||
keep_alive = 10
|
||||
log_level = "off"
|
||||
|
@ -236,8 +220,6 @@ mod tests {
|
|||
|
||||
let config = Config::from(Config::figment());
|
||||
assert_eq!(config, Config {
|
||||
address: Ipv4Addr::new(1, 2, 3, 4).into(),
|
||||
port: 1234,
|
||||
workers: 20,
|
||||
ident: ident!("Something Else Cool"),
|
||||
keep_alive: 10,
|
||||
|
@ -249,8 +231,6 @@ mod tests {
|
|||
jail.set_env("ROCKET_CONFIG", "Other.toml");
|
||||
jail.create_file("Other.toml", r#"
|
||||
[default]
|
||||
address = "1.2.3.4"
|
||||
port = 1234
|
||||
workers = 20
|
||||
keep_alive = 10
|
||||
log_level = "off"
|
||||
|
@ -259,8 +239,6 @@ mod tests {
|
|||
|
||||
let config = Config::from(Config::figment());
|
||||
assert_eq!(config, Config {
|
||||
address: Ipv4Addr::new(1, 2, 3, 4).into(),
|
||||
port: 1234,
|
||||
workers: 20,
|
||||
keep_alive: 10,
|
||||
log_level: LogLevel::Off,
|
||||
|
@ -367,228 +345,6 @@ mod tests {
|
|||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "tls")]
|
||||
fn test_tls_config_from_file() {
|
||||
use crate::config::{TlsConfig, CipherSuite, Ident, Shutdown};
|
||||
|
||||
figment::Jail::expect_with(|jail| {
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[global]
|
||||
shutdown.ctrlc = 0
|
||||
ident = false
|
||||
|
||||
[global.tls]
|
||||
certs = "/ssl/cert.pem"
|
||||
key = "/ssl/key.pem"
|
||||
|
||||
[global.limits]
|
||||
forms = "1mib"
|
||||
json = "10mib"
|
||||
stream = "50kib"
|
||||
"#)?;
|
||||
|
||||
let config = Config::from(Config::figment());
|
||||
assert_eq!(config, Config {
|
||||
shutdown: Shutdown { ctrlc: false, ..Default::default() },
|
||||
ident: Ident::none(),
|
||||
tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")),
|
||||
limits: Limits::default()
|
||||
.limit("forms", 1.mebibytes())
|
||||
.limit("json", 10.mebibytes())
|
||||
.limit("stream", 50.kibibytes()),
|
||||
..Config::default()
|
||||
});
|
||||
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[global.tls]
|
||||
certs = "cert.pem"
|
||||
key = "key.pem"
|
||||
"#)?;
|
||||
|
||||
let config = Config::from(Config::figment());
|
||||
assert_eq!(config, Config {
|
||||
tls: Some(TlsConfig::from_paths(
|
||||
jail.directory().join("cert.pem"),
|
||||
jail.directory().join("key.pem")
|
||||
)),
|
||||
..Config::default()
|
||||
});
|
||||
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[global.tls]
|
||||
certs = "cert.pem"
|
||||
key = "key.pem"
|
||||
prefer_server_cipher_order = true
|
||||
ciphers = [
|
||||
"TLS_CHACHA20_POLY1305_SHA256",
|
||||
"TLS_AES_256_GCM_SHA384",
|
||||
"TLS_AES_128_GCM_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
||||
]
|
||||
"#)?;
|
||||
|
||||
let config = Config::from(Config::figment());
|
||||
let cert_path = jail.directory().join("cert.pem");
|
||||
let key_path = jail.directory().join("key.pem");
|
||||
assert_eq!(config, Config {
|
||||
tls: Some(TlsConfig::from_paths(cert_path, key_path)
|
||||
.with_preferred_server_cipher_order(true)
|
||||
.with_ciphers([
|
||||
CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
|
||||
CipherSuite::TLS_AES_256_GCM_SHA384,
|
||||
CipherSuite::TLS_AES_128_GCM_SHA256,
|
||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
])),
|
||||
..Config::default()
|
||||
});
|
||||
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[global]
|
||||
shutdown.ctrlc = 0
|
||||
ident = false
|
||||
|
||||
[global.tls]
|
||||
certs = "/ssl/cert.pem"
|
||||
key = "/ssl/key.pem"
|
||||
|
||||
[global.limits]
|
||||
forms = "1mib"
|
||||
json = "10mib"
|
||||
stream = "50kib"
|
||||
"#)?;
|
||||
|
||||
let config = Config::from(Config::figment());
|
||||
assert_eq!(config, Config {
|
||||
shutdown: Shutdown { ctrlc: false, ..Default::default() },
|
||||
ident: Ident::none(),
|
||||
tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")),
|
||||
limits: Limits::default()
|
||||
.limit("forms", 1.mebibytes())
|
||||
.limit("json", 10.mebibytes())
|
||||
.limit("stream", 50.kibibytes()),
|
||||
..Config::default()
|
||||
});
|
||||
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[global.tls]
|
||||
certs = "cert.pem"
|
||||
key = "key.pem"
|
||||
"#)?;
|
||||
|
||||
let config = Config::from(Config::figment());
|
||||
assert_eq!(config, Config {
|
||||
tls: Some(TlsConfig::from_paths(
|
||||
jail.directory().join("cert.pem"),
|
||||
jail.directory().join("key.pem")
|
||||
)),
|
||||
..Config::default()
|
||||
});
|
||||
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[global.tls]
|
||||
certs = "cert.pem"
|
||||
key = "key.pem"
|
||||
prefer_server_cipher_order = true
|
||||
ciphers = [
|
||||
"TLS_CHACHA20_POLY1305_SHA256",
|
||||
"TLS_AES_256_GCM_SHA384",
|
||||
"TLS_AES_128_GCM_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
||||
]
|
||||
"#)?;
|
||||
|
||||
let config = Config::from(Config::figment());
|
||||
let cert_path = jail.directory().join("cert.pem");
|
||||
let key_path = jail.directory().join("key.pem");
|
||||
assert_eq!(config, Config {
|
||||
tls: Some(TlsConfig::from_paths(cert_path, key_path)
|
||||
.with_preferred_server_cipher_order(true)
|
||||
.with_ciphers([
|
||||
CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
|
||||
CipherSuite::TLS_AES_256_GCM_SHA384,
|
||||
CipherSuite::TLS_AES_128_GCM_SHA256,
|
||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
])),
|
||||
..Config::default()
|
||||
});
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mtls")]
|
||||
fn test_mtls_config() {
|
||||
use std::path::Path;
|
||||
|
||||
figment::Jail::expect_with(|jail| {
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[default.tls]
|
||||
certs = "/ssl/cert.pem"
|
||||
key = "/ssl/key.pem"
|
||||
"#)?;
|
||||
|
||||
let config = Config::from(Config::figment());
|
||||
assert!(config.tls.is_some());
|
||||
assert!(config.tls.as_ref().unwrap().mutual.is_none());
|
||||
assert!(config.tls_enabled());
|
||||
assert!(!config.mtls_enabled());
|
||||
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[default.tls]
|
||||
certs = "/ssl/cert.pem"
|
||||
key = "/ssl/key.pem"
|
||||
mutual = { ca_certs = "/ssl/ca.pem" }
|
||||
"#)?;
|
||||
|
||||
let config = Config::from(Config::figment());
|
||||
assert!(config.tls_enabled());
|
||||
assert!(config.mtls_enabled());
|
||||
|
||||
let mtls = config.tls.as_ref().unwrap().mutual.as_ref().unwrap();
|
||||
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
|
||||
assert!(!mtls.mandatory);
|
||||
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[default.tls]
|
||||
certs = "/ssl/cert.pem"
|
||||
key = "/ssl/key.pem"
|
||||
|
||||
[default.tls.mutual]
|
||||
ca_certs = "/ssl/ca.pem"
|
||||
mandatory = true
|
||||
"#)?;
|
||||
|
||||
let config = Config::from(Config::figment());
|
||||
let mtls = config.tls.as_ref().unwrap().mutual.as_ref().unwrap();
|
||||
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
|
||||
assert!(mtls.mandatory);
|
||||
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[default.tls]
|
||||
certs = "/ssl/cert.pem"
|
||||
key = "/ssl/key.pem"
|
||||
mutual = { ca_certs = "relative/ca.pem" }
|
||||
"#)?;
|
||||
|
||||
let config = Config::from(Config::figment());
|
||||
let mtls = config.tls.as_ref().unwrap().mutual().unwrap();
|
||||
assert_eq!(mtls.ca_certs().unwrap_left(),
|
||||
jail.directory().join("relative/ca.pem"));
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_profiles_merge() {
|
||||
figment::Jail::expect_with(|jail| {
|
||||
|
@ -629,42 +385,41 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "tls")]
|
||||
fn test_env_vars_merge() {
|
||||
use crate::config::{TlsConfig, Ident};
|
||||
use crate::config::{Ident, Shutdown};
|
||||
|
||||
figment::Jail::expect_with(|jail| {
|
||||
jail.set_env("ROCKET_PORT", 9999);
|
||||
jail.set_env("ROCKET_KEEP_ALIVE", 9999);
|
||||
let config = Config::from(Config::figment());
|
||||
assert_eq!(config, Config {
|
||||
port: 9999,
|
||||
keep_alive: 9999,
|
||||
..Config::default()
|
||||
});
|
||||
|
||||
jail.set_env("ROCKET_TLS", r#"{certs="certs.pem"}"#);
|
||||
jail.set_env("ROCKET_SHUTDOWN", r#"{grace=7}"#);
|
||||
let first_figment = Config::figment();
|
||||
jail.set_env("ROCKET_TLS", r#"{key="key.pem"}"#);
|
||||
jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=10}"#);
|
||||
let prev_figment = Config::figment().join(&first_figment);
|
||||
let config = Config::from(&prev_figment);
|
||||
assert_eq!(config, Config {
|
||||
port: 9999,
|
||||
tls: Some(TlsConfig::from_paths("certs.pem", "key.pem")),
|
||||
keep_alive: 9999,
|
||||
shutdown: Shutdown { grace: 7, mercy: 10, ..Default::default() },
|
||||
..Config::default()
|
||||
});
|
||||
|
||||
jail.set_env("ROCKET_TLS", r#"{certs="new.pem"}"#);
|
||||
jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=20}"#);
|
||||
let config = Config::from(Config::figment().join(&prev_figment));
|
||||
assert_eq!(config, Config {
|
||||
port: 9999,
|
||||
tls: Some(TlsConfig::from_paths("new.pem", "key.pem")),
|
||||
keep_alive: 9999,
|
||||
shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
|
||||
..Config::default()
|
||||
});
|
||||
|
||||
jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#);
|
||||
let config = Config::from(Config::figment().join(&prev_figment));
|
||||
assert_eq!(config, Config {
|
||||
port: 9999,
|
||||
tls: Some(TlsConfig::from_paths("new.pem", "key.pem")),
|
||||
keep_alive: 9999,
|
||||
shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
|
||||
limits: Limits::default().limit("stream", 100.kibibytes()),
|
||||
..Config::default()
|
||||
});
|
||||
|
@ -672,8 +427,8 @@ mod tests {
|
|||
jail.set_env("ROCKET_IDENT", false);
|
||||
let config = Config::from(Config::figment().join(&prev_figment));
|
||||
assert_eq!(config, Config {
|
||||
port: 9999,
|
||||
tls: Some(TlsConfig::from_paths("new.pem", "key.pem")),
|
||||
keep_alive: 9999,
|
||||
shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
|
||||
limits: Limits::default().limit("stream", 100.kibibytes()),
|
||||
ident: Ident::none(),
|
||||
..Config::default()
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use std::fmt;
|
||||
|
||||
use cookie::Key;
|
||||
use serde::{de, ser, Deserialize, Serialize};
|
||||
|
||||
use crate::http::private::cookie::Key;
|
||||
use crate::request::{Outcome, Request, FromRequest};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::fmt;
|
||||
use std::{fmt, time::Duration};
|
||||
|
||||
#[cfg(unix)]
|
||||
use std::collections::HashSet;
|
||||
|
@ -291,6 +291,14 @@ impl Default for Shutdown {
|
|||
}
|
||||
|
||||
impl Shutdown {
|
||||
pub(crate) fn grace(&self) -> Duration {
|
||||
Duration::from_secs(self.grace as u64)
|
||||
}
|
||||
|
||||
pub(crate) fn mercy(&self) -> Duration {
|
||||
Duration::from_secs(self.mercy as u64)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) fn signal_stream(&self) -> Option<impl Stream<Item = Sig>> {
|
||||
use tokio_stream::{StreamExt, StreamMap, wrappers::SignalStream};
|
||||
|
|
|
@ -3,16 +3,16 @@ use std::task::{Context, Poll};
|
|||
use std::path::Path;
|
||||
use std::io::{self, Cursor};
|
||||
|
||||
use futures::ready;
|
||||
use futures::stream::Stream;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
|
||||
use tokio_util::io::StreamReader;
|
||||
use futures::{ready, stream::Stream};
|
||||
use hyper::body::{Body, Bytes, Incoming as HyperBody};
|
||||
|
||||
use crate::http::hyper;
|
||||
use crate::ext::{PollExt, Chain};
|
||||
use crate::data::{Capped, N};
|
||||
use crate::http::hyper::body::Bytes;
|
||||
use crate::data::transform::Transform;
|
||||
use crate::util::Chain;
|
||||
|
||||
use super::peekable::Peekable;
|
||||
use super::transform::TransformBuf;
|
||||
|
@ -68,7 +68,7 @@ pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
|
|||
/// Raw underlying data stream.
|
||||
pub enum RawStream<'r> {
|
||||
Empty,
|
||||
Body(&'r mut hyper::Body),
|
||||
Body(&'r mut HyperBody),
|
||||
Multipart(multer::Field<'r>),
|
||||
}
|
||||
|
||||
|
@ -154,8 +154,14 @@ impl<'r> DataStream<'r> {
|
|||
/// ```
|
||||
pub fn hint(&self) -> usize {
|
||||
let base = self.base();
|
||||
let buf_len = base.get_ref().get_ref().0.get_ref().len();
|
||||
std::cmp::min(buf_len, base.limit() as usize)
|
||||
if let (Some(cursor), _) = base.get_ref().get_ref() {
|
||||
let len = cursor.get_ref().len() as u64;
|
||||
let position = cursor.position().min(len);
|
||||
let remaining = len - position;
|
||||
remaining.min(base.limit()) as usize
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// A helper method to write the body of the request to any `AsyncWrite`
|
||||
|
@ -331,17 +337,25 @@ impl Stream for RawStream<'_> {
|
|||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match self.get_mut() {
|
||||
RawStream::Body(body) => Pin::new(body).poll_next(cx)
|
||||
.map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
|
||||
RawStream::Multipart(mp) => Pin::new(mp).poll_next(cx)
|
||||
.map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
|
||||
// TODO: Expose trailer headers, somehow.
|
||||
RawStream::Body(body) => {
|
||||
Pin::new(body)
|
||||
.poll_frame(cx)
|
||||
.map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new()))
|
||||
.map_err(io::Error::other)
|
||||
}
|
||||
RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other),
|
||||
RawStream::Empty => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
match self {
|
||||
RawStream::Body(body) => body.size_hint(),
|
||||
RawStream::Body(body) => {
|
||||
let hint = body.size_hint();
|
||||
let (lower, upper) = (hint.lower(), hint.upper());
|
||||
(lower as usize, upper.map(|x| x as usize))
|
||||
},
|
||||
RawStream::Multipart(mp) => mp.size_hint(),
|
||||
RawStream::Empty => (0, Some(0)),
|
||||
}
|
||||
|
@ -358,8 +372,8 @@ impl std::fmt::Display for RawStream<'_> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'r> From<&'r mut hyper::Body> for RawStream<'r> {
|
||||
fn from(value: &'r mut hyper::Body) -> Self {
|
||||
impl<'r> From<&'r mut HyperBody> for RawStream<'r> {
|
||||
fn from(value: &'r mut HyperBody) -> Self {
|
||||
Self::Body(value)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,8 @@ use std::task::{Context, Poll};
|
|||
use std::pin::Pin;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use crate::http::hyper::upgrade::Upgraded;
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper_util::rt::TokioIo;
|
||||
|
||||
/// A bidirectional, raw stream to the client.
|
||||
///
|
||||
|
@ -28,7 +28,7 @@ pub struct IoStream {
|
|||
|
||||
/// Just in case we want to add stream kinds in the future.
|
||||
enum IoStreamKind {
|
||||
Upgraded(Upgraded)
|
||||
Upgraded(TokioIo<Upgraded>)
|
||||
}
|
||||
|
||||
/// An upgraded connection I/O handler.
|
||||
|
@ -51,7 +51,7 @@ enum IoStreamKind {
|
|||
///
|
||||
/// #[rocket::async_trait]
|
||||
/// impl IoHandler for EchoHandler {
|
||||
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
||||
/// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||
/// let (mut reader, mut writer) = io::split(io);
|
||||
/// io::copy(&mut reader, &mut writer).await?;
|
||||
/// Ok(())
|
||||
|
@ -68,13 +68,20 @@ enum IoStreamKind {
|
|||
#[crate::async_trait]
|
||||
pub trait IoHandler: Send {
|
||||
/// Performs the raw I/O.
|
||||
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()>;
|
||||
async fn io(self: Box<Self>, io: IoStream) -> io::Result<()>;
|
||||
}
|
||||
|
||||
#[crate::async_trait]
|
||||
impl IoHandler for () {
|
||||
async fn io(self: Box<Self>, _: IoStream) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
impl From<Upgraded> for IoStream {
|
||||
fn from(io: Upgraded) -> Self {
|
||||
IoStream { kind: IoStreamKind::Upgraded(io) }
|
||||
IoStream { kind: IoStreamKind::Upgraded(TokioIo::new(io)) }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -178,7 +178,7 @@ impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
|
|||
#[allow(deprecated)]
|
||||
mod tests {
|
||||
use std::hash::SipHasher;
|
||||
use std::sync::{Arc, atomic::{AtomicU64, AtomicU8}};
|
||||
use std::sync::{Arc, atomic::{AtomicU8, AtomicU64, Ordering}};
|
||||
|
||||
use parking_lot::Mutex;
|
||||
use ubyte::ToByteUnit;
|
||||
|
@ -264,41 +264,41 @@ mod tests {
|
|||
assert_eq!(bytes.len(), 8);
|
||||
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
|
||||
let value = u64::from_be_bytes(bytes);
|
||||
hash1.store(value, atomic::Ordering::Release);
|
||||
hash1.store(value, Ordering::Release);
|
||||
})
|
||||
.chain_inspect(move |bytes| {
|
||||
assert_eq!(bytes.len(), 8);
|
||||
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
|
||||
let value = u64::from_be_bytes(bytes);
|
||||
let prev = hash2.load(atomic::Ordering::Acquire);
|
||||
let prev = hash2.load(Ordering::Acquire);
|
||||
assert_eq!(prev, value);
|
||||
inspect2.fetch_add(1, atomic::Ordering::Release);
|
||||
inspect2.fetch_add(1, Ordering::Release);
|
||||
});
|
||||
})));
|
||||
|
||||
// Make sure nothing has happened yet.
|
||||
assert!(raw_data.lock().is_empty());
|
||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0);
|
||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0);
|
||||
assert_eq!(hash.load(Ordering::Acquire), 0);
|
||||
assert_eq!(inspect2.load(Ordering::Acquire), 0);
|
||||
|
||||
// Check that nothing happens if the data isn't read.
|
||||
let client = Client::debug(rocket).unwrap();
|
||||
client.get("/").body("Hello, world!").dispatch();
|
||||
assert!(raw_data.lock().is_empty());
|
||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0);
|
||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0);
|
||||
assert_eq!(hash.load(Ordering::Acquire), 0);
|
||||
assert_eq!(inspect2.load(Ordering::Acquire), 0);
|
||||
|
||||
// Check inspect + hash + inspect + inspect.
|
||||
client.post("/").body("Hello, world!").dispatch();
|
||||
assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes());
|
||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0xae5020d7cf49d14f);
|
||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 1);
|
||||
assert_eq!(hash.load(Ordering::Acquire), 0xae5020d7cf49d14f);
|
||||
assert_eq!(inspect2.load(Ordering::Acquire), 1);
|
||||
|
||||
// Check inspect + hash + inspect + inspect, round 2.
|
||||
let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
|
||||
client.post("/").body(string).dispatch();
|
||||
assert_eq!(raw_data.lock().as_slice(), string.as_bytes());
|
||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0x323f9aa98f907faf);
|
||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 2);
|
||||
assert_eq!(hash.load(Ordering::Acquire), 0x323f9aa98f907faf);
|
||||
assert_eq!(inspect2.load(Ordering::Acquire), 2);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
use std::io;
|
||||
use std::mem::transmute;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Poll, Context};
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
use http::request::Parts;
|
||||
use hyper::body::Incoming;
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
|
||||
use crate::data::{Data, IoHandler};
|
||||
use crate::{Request, Response, Rocket, Orbit};
|
||||
|
||||
// TODO: Magic with trait async fn to get rid of the box pin.
|
||||
// TODO: Write safety proofs.
|
||||
|
||||
macro_rules! static_assert_covariance {
|
||||
($T:tt) => (
|
||||
const _: () = {
|
||||
fn _assert_covariance<'x: 'y, 'y>(x: &'y $T<'x>) -> &'y $T<'y> { x }
|
||||
};
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ErasedRequest {
|
||||
// XXX: SAFETY: This (dependent) field must come first due to drop order!
|
||||
request: Request<'static>,
|
||||
_rocket: Arc<Rocket<Orbit>>,
|
||||
_parts: Box<Parts>,
|
||||
}
|
||||
|
||||
impl Drop for ErasedRequest {
|
||||
fn drop(&mut self) { }
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ErasedResponse {
|
||||
// XXX: SAFETY: This (dependent) field must come first due to drop order!
|
||||
response: Response<'static>,
|
||||
_request: Arc<ErasedRequest>,
|
||||
_incoming: Box<Incoming>,
|
||||
}
|
||||
|
||||
impl Drop for ErasedResponse {
|
||||
fn drop(&mut self) { }
|
||||
}
|
||||
|
||||
pub struct ErasedIoHandler {
|
||||
// XXX: SAFETY: This (dependent) field must come first due to drop order!
|
||||
io: Box<dyn IoHandler + 'static>,
|
||||
_request: Arc<ErasedRequest>,
|
||||
}
|
||||
|
||||
impl Drop for ErasedIoHandler {
|
||||
fn drop(&mut self) { }
|
||||
}
|
||||
|
||||
impl ErasedRequest {
|
||||
pub fn new(
|
||||
rocket: Arc<Rocket<Orbit>>,
|
||||
parts: Parts,
|
||||
constructor: impl for<'r> FnOnce(
|
||||
&'r Rocket<Orbit>,
|
||||
&'r Parts
|
||||
) -> Request<'r>,
|
||||
) -> ErasedRequest {
|
||||
let rocket: Arc<Rocket<Orbit>> = rocket;
|
||||
let parts: Box<Parts> = Box::new(parts);
|
||||
let request: Request<'_> = {
|
||||
let rocket: &Rocket<Orbit> = &*rocket;
|
||||
let rocket: &'static Rocket<Orbit> = unsafe { transmute(rocket) };
|
||||
let parts: &Parts = &*parts;
|
||||
let parts: &'static Parts = unsafe { transmute(parts) };
|
||||
constructor(&rocket, &parts)
|
||||
};
|
||||
|
||||
ErasedRequest { _rocket: rocket, _parts: parts, request, }
|
||||
}
|
||||
|
||||
pub async fn into_response<T: Send + Sync + 'static>(
|
||||
self,
|
||||
incoming: Incoming,
|
||||
data_builder: impl for<'r> FnOnce(&'r mut Incoming) -> Data<'r>,
|
||||
preprocess: impl for<'r, 'x> FnOnce(
|
||||
&'r Rocket<Orbit>,
|
||||
&'r mut Request<'x>,
|
||||
&'r mut Data<'x>
|
||||
) -> BoxFuture<'r, T>,
|
||||
dispatch: impl for<'r> FnOnce(
|
||||
T,
|
||||
&'r Rocket<Orbit>,
|
||||
&'r Request<'r>,
|
||||
Data<'r>
|
||||
) -> BoxFuture<'r, Response<'r>>,
|
||||
) -> ErasedResponse {
|
||||
let mut incoming = Box::new(incoming);
|
||||
let mut data: Data<'_> = {
|
||||
let incoming: &mut Incoming = &mut *incoming;
|
||||
let incoming: &'static mut Incoming = unsafe { transmute(incoming) };
|
||||
data_builder(incoming)
|
||||
};
|
||||
|
||||
let mut parent = Arc::new(self);
|
||||
let token: T = {
|
||||
let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap();
|
||||
let rocket: &Rocket<Orbit> = &*parent._rocket;
|
||||
let request: &mut Request<'_> = &mut parent.request;
|
||||
let data: &mut Data<'_> = &mut data;
|
||||
preprocess(rocket, request, data).await
|
||||
};
|
||||
|
||||
let parent = parent;
|
||||
let response: Response<'_> = {
|
||||
let parent: &ErasedRequest = &*parent;
|
||||
let parent: &'static ErasedRequest = unsafe { transmute(parent) };
|
||||
let rocket: &Rocket<Orbit> = &*parent._rocket;
|
||||
let request: &Request<'_> = &parent.request;
|
||||
dispatch(token, rocket, request, data).await
|
||||
};
|
||||
|
||||
ErasedResponse {
|
||||
_request: parent,
|
||||
_incoming: incoming,
|
||||
response: response,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ErasedResponse {
|
||||
pub fn inner<'a>(&'a self) -> &'a Response<'a> {
|
||||
static_assert_covariance!(Response);
|
||||
&self.response
|
||||
}
|
||||
|
||||
pub fn with_inner_mut<'a, T>(
|
||||
&'a mut self,
|
||||
f: impl for<'r> FnOnce(&'a mut Response<'r>) -> T
|
||||
) -> T {
|
||||
static_assert_covariance!(Response);
|
||||
f(&mut self.response)
|
||||
}
|
||||
|
||||
pub fn to_io_handler<'a>(
|
||||
&'a mut self,
|
||||
constructor: impl for<'r> FnOnce(
|
||||
&'r Request<'r>,
|
||||
&'a mut Response<'r>,
|
||||
) -> Option<Box<dyn IoHandler + 'r>>
|
||||
) -> Option<ErasedIoHandler> {
|
||||
let parent: Arc<ErasedRequest> = self._request.clone();
|
||||
let io: Option<Box<dyn IoHandler + '_>> = {
|
||||
let parent: &ErasedRequest = &*parent;
|
||||
let parent: &'static ErasedRequest = unsafe { transmute(parent) };
|
||||
let request: &Request<'_> = &parent.request;
|
||||
constructor(request, &mut self.response)
|
||||
};
|
||||
|
||||
io.map(|io| ErasedIoHandler { _request: parent, io })
|
||||
}
|
||||
}
|
||||
|
||||
impl ErasedIoHandler {
|
||||
pub fn with_inner_mut<'a, T: 'a>(
|
||||
&'a mut self,
|
||||
f: impl for<'r> FnOnce(&'a mut Box<dyn IoHandler + 'r>) -> T
|
||||
) -> T {
|
||||
fn _assert_covariance<'x: 'y, 'y>(
|
||||
x: &'y Box<dyn IoHandler + 'x>
|
||||
) -> &'y Box<dyn IoHandler + 'y> { x }
|
||||
|
||||
f(&mut self.io)
|
||||
}
|
||||
|
||||
pub fn take<'a>(&'a mut self) -> Box<dyn IoHandler + 'a> {
|
||||
fn _assert_covariance<'x: 'y, 'y>(
|
||||
x: &'y Box<dyn IoHandler + 'x>
|
||||
) -> &'y Box<dyn IoHandler + 'y> { x }
|
||||
|
||||
self.with_inner_mut(|handler| std::mem::replace(handler, Box::new(())))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for ErasedResponse {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.get_mut().with_inner_mut(|r| Pin::new(r.body_mut()).poll_read(cx, buf))
|
||||
}
|
||||
}
|
|
@ -74,11 +74,8 @@ pub struct Error {
|
|||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum ErrorKind {
|
||||
/// Binding to the provided address/port failed.
|
||||
Bind(io::Error),
|
||||
/// Binding via TLS to the provided address/port failed.
|
||||
#[cfg(feature = "tls")]
|
||||
TlsBind(crate::http::tls::error::Error),
|
||||
/// Binding to the network interface failed.
|
||||
Bind(Box<dyn StdError + Send>),
|
||||
/// An I/O error occurred during launch.
|
||||
Io(io::Error),
|
||||
/// A valid [`Config`](crate::Config) could not be extracted from the
|
||||
|
@ -90,15 +87,10 @@ pub enum ErrorKind {
|
|||
FailedFairings(Vec<crate::fairing::Info>),
|
||||
/// Sentinels requested abort.
|
||||
SentinelAborts(Vec<crate::sentinel::Sentry>),
|
||||
/// The configuration profile is not debug but not secret key is configured.
|
||||
/// The configuration profile is not debug but no secret key is configured.
|
||||
InsecureSecretKey(Profile),
|
||||
/// Shutdown failed.
|
||||
Shutdown(
|
||||
/// The instance of Rocket that failed to shutdown.
|
||||
Arc<Rocket<Orbit>>,
|
||||
/// The error that occurred during shutdown, if any.
|
||||
Option<Box<dyn StdError + Send + Sync>>
|
||||
),
|
||||
/// Shutdown failed. Contains the Rocket instance that failed to shutdown.
|
||||
Shutdown(Arc<Rocket<Orbit>>),
|
||||
}
|
||||
|
||||
/// An error that occurs when a value was unexpectedly empty.
|
||||
|
@ -111,20 +103,24 @@ impl From<ErrorKind> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<figment::Error> for Error {
|
||||
fn from(e: figment::Error) -> Self {
|
||||
Error::new(ErrorKind::Config(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
fn from(e: io::Error) -> Self {
|
||||
Error::new(ErrorKind::Io(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl Error {
|
||||
#[inline(always)]
|
||||
pub(crate) fn new(kind: ErrorKind) -> Error {
|
||||
Error { handled: AtomicBool::new(false), kind }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn shutdown<E>(rocket: Arc<Rocket<Orbit>>, error: E) -> Error
|
||||
where E: Into<Option<crate::http::hyper::Error>>
|
||||
{
|
||||
let error = error.into().map(|e| Box::new(e) as Box<dyn StdError + Sync + Send>);
|
||||
Error::new(ErrorKind::Shutdown(rocket, error))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn was_handled(&self) -> bool {
|
||||
self.handled.load(Ordering::Acquire)
|
||||
|
@ -176,9 +172,9 @@ impl Error {
|
|||
self.mark_handled();
|
||||
match self.kind() {
|
||||
ErrorKind::Bind(ref e) => {
|
||||
error!("Rocket failed to bind network socket to given address/port.");
|
||||
error!("Binding to the network interface failed.");
|
||||
info_!("{}", e);
|
||||
"aborting due to socket bind error"
|
||||
"aborting due to bind error"
|
||||
}
|
||||
ErrorKind::Io(ref e) => {
|
||||
error!("Rocket failed to launch due to an I/O error.");
|
||||
|
@ -229,20 +225,10 @@ impl Error {
|
|||
|
||||
"aborting due to sentinel-triggered abort(s)"
|
||||
}
|
||||
ErrorKind::Shutdown(_, error) => {
|
||||
ErrorKind::Shutdown(_) => {
|
||||
error!("Rocket failed to shutdown gracefully.");
|
||||
if let Some(e) = error {
|
||||
info_!("{}", e);
|
||||
}
|
||||
|
||||
"aborting due to failed shutdown"
|
||||
}
|
||||
#[cfg(feature = "tls")]
|
||||
ErrorKind::TlsBind(e) => {
|
||||
error!("Rocket failed to bind via TLS to network socket.");
|
||||
info_!("{}", e);
|
||||
"aborting due to TLS bind error"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -260,10 +246,7 @@ impl fmt::Display for ErrorKind {
|
|||
ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f),
|
||||
ErrorKind::Config(_) => "failed to extract configuration".fmt(f),
|
||||
ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f),
|
||||
ErrorKind::Shutdown(_, Some(e)) => write!(f, "shutdown failed: {e}"),
|
||||
ErrorKind::Shutdown(_, None) => "shutdown failed".fmt(f),
|
||||
#[cfg(feature = "tls")]
|
||||
ErrorKind::TlsBind(e) => write!(f, "TLS bind failed: {e}"),
|
||||
ErrorKind::Shutdown(_) => "shutdown failed".fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -308,3 +291,42 @@ impl fmt::Display for Empty {
|
|||
}
|
||||
|
||||
impl StdError for Empty { }
|
||||
|
||||
/// Log an error that occurs during request processing
|
||||
pub(crate) fn log_server_error(error: &Box<dyn StdError + Send + Sync>) {
|
||||
struct ServerError<'a>(&'a (dyn StdError + 'static));
|
||||
|
||||
impl fmt::Display for ServerError<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let error = &self.0;
|
||||
if let Some(e) = error.downcast_ref::<hyper::Error>() {
|
||||
write!(f, "request processing failed: {e}")?;
|
||||
} else if let Some(e) = error.downcast_ref::<io::Error>() {
|
||||
write!(f, "connection I/O error: ")?;
|
||||
|
||||
match e.kind() {
|
||||
io::ErrorKind::NotConnected => write!(f, "remote disconnected")?,
|
||||
io::ErrorKind::UnexpectedEof => write!(f, "remote sent early eof")?,
|
||||
io::ErrorKind::ConnectionReset
|
||||
| io::ErrorKind::ConnectionAborted
|
||||
| io::ErrorKind::BrokenPipe => write!(f, "terminated by remote")?,
|
||||
_ => write!(f, "{e}")?,
|
||||
}
|
||||
} else {
|
||||
write!(f, "http server error: {error}")?;
|
||||
}
|
||||
|
||||
if let Some(e) = error.source() {
|
||||
write!(f, " ({})", ServerError(e))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
if error.downcast_ref::<hyper::Error>().is_some() {
|
||||
warn!("{}", ServerError(&**error))
|
||||
} else {
|
||||
error!("{}", ServerError(&**error))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,404 +0,0 @@
|
|||
use std::{io, time::Duration};
|
||||
use std::task::{Poll, Context};
|
||||
use std::pin::Pin;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::time::{sleep, Sleep};
|
||||
|
||||
use futures::stream::Stream;
|
||||
use futures::future::{self, Future, FutureExt};
|
||||
|
||||
pin_project! {
|
||||
pub struct ReaderStream<R> {
|
||||
#[pin]
|
||||
reader: Option<R>,
|
||||
buf: BytesMut,
|
||||
cap: usize,
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead> Stream for ReaderStream<R> {
|
||||
type Item = std::io::Result<Bytes>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
use tokio_util::io::poll_read_buf;
|
||||
|
||||
let mut this = self.as_mut().project();
|
||||
|
||||
let reader = match this.reader.as_pin_mut() {
|
||||
Some(r) => r,
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
|
||||
if this.buf.capacity() == 0 {
|
||||
this.buf.reserve(*this.cap);
|
||||
}
|
||||
|
||||
match poll_read_buf(reader, cx, &mut this.buf) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(Err(err)) => {
|
||||
self.project().reader.set(None);
|
||||
Poll::Ready(Some(Err(err)))
|
||||
}
|
||||
Poll::Ready(Ok(0)) => {
|
||||
self.project().reader.set(None);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Ready(Ok(_)) => {
|
||||
let chunk = this.buf.split();
|
||||
Poll::Ready(Some(Ok(chunk.freeze())))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AsyncReadExt: AsyncRead + Sized {
|
||||
fn into_bytes_stream(self, cap: usize) -> ReaderStream<Self> {
|
||||
ReaderStream { reader: Some(self), cap, buf: BytesMut::with_capacity(cap) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead> AsyncReadExt for T { }
|
||||
|
||||
pub trait PollExt<T, E> {
|
||||
fn map_err_ext<U, F>(self, f: F) -> Poll<Option<Result<T, U>>>
|
||||
where F: FnOnce(E) -> U;
|
||||
}
|
||||
|
||||
impl<T, E> PollExt<T, E> for Poll<Option<Result<T, E>>> {
|
||||
/// Changes the error value of this `Poll` with the closure provided.
|
||||
fn map_err_ext<U, F>(self, f: F) -> Poll<Option<Result<T, U>>>
|
||||
where F: FnOnce(E) -> U
|
||||
{
|
||||
match self {
|
||||
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(t))),
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(f(e)))),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
|
||||
#[must_use = "streams do nothing unless polled"]
|
||||
pub struct Chain<T, U> {
|
||||
#[pin]
|
||||
first: T,
|
||||
#[pin]
|
||||
second: U,
|
||||
done_first: bool,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
|
||||
pub(crate) fn new(first: T, second: U) -> Self {
|
||||
Self { first, second, done_first: false }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
|
||||
/// Gets references to the underlying readers in this `Chain`.
|
||||
pub fn get_ref(&self) -> (&T, &U) {
|
||||
(&self.first, &self.second)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let me = self.project();
|
||||
|
||||
if !*me.done_first {
|
||||
let init_rem = buf.remaining();
|
||||
futures::ready!(me.first.poll_read(cx, buf))?;
|
||||
if buf.remaining() == init_rem {
|
||||
*me.done_first = true;
|
||||
} else {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
me.second.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
enum State {
|
||||
/// I/O has not been cancelled. Proceed as normal.
|
||||
Active,
|
||||
/// I/O has been cancelled. See if we can finish before the timer expires.
|
||||
Grace(Pin<Box<Sleep>>),
|
||||
/// Grace period elapsed. Shutdown the connection, waiting for the timer
|
||||
/// until we force close.
|
||||
Mercy(Pin<Box<Sleep>>),
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// I/O that can be cancelled when a future `F` resolves.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct CancellableIo<F, I> {
|
||||
#[pin]
|
||||
io: Option<I>,
|
||||
#[pin]
|
||||
trigger: future::Fuse<F>,
|
||||
state: State,
|
||||
grace: Duration,
|
||||
mercy: Duration,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
|
||||
pub fn new(trigger: F, io: I, grace: Duration, mercy: Duration) -> Self {
|
||||
CancellableIo {
|
||||
grace, mercy,
|
||||
io: Some(io),
|
||||
trigger: trigger.fuse(),
|
||||
state: State::Active,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn io(&self) -> Option<&I> {
|
||||
self.io.as_ref()
|
||||
}
|
||||
|
||||
/// Run `do_io` while connection processing should continue.
|
||||
fn poll_trigger_then<T>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
|
||||
) -> Poll<io::Result<T>> {
|
||||
let mut me = self.as_mut().project();
|
||||
let io = match me.io.as_pin_mut() {
|
||||
Some(io) => io,
|
||||
None => return Poll::Ready(Err(gone())),
|
||||
};
|
||||
|
||||
loop {
|
||||
match me.state {
|
||||
State::Active => {
|
||||
if me.trigger.as_mut().poll(cx).is_ready() {
|
||||
*me.state = State::Grace(Box::pin(sleep(*me.grace)));
|
||||
} else {
|
||||
return do_io(io, cx);
|
||||
}
|
||||
}
|
||||
State::Grace(timer) => {
|
||||
if timer.as_mut().poll(cx).is_ready() {
|
||||
*me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
|
||||
} else {
|
||||
return do_io(io, cx);
|
||||
}
|
||||
}
|
||||
State::Mercy(timer) => {
|
||||
if timer.as_mut().poll(cx).is_ready() {
|
||||
self.project().io.set(None);
|
||||
return Poll::Ready(Err(time_out()));
|
||||
} else {
|
||||
let result = futures::ready!(io.poll_shutdown(cx));
|
||||
self.project().io.set(None);
|
||||
return match result {
|
||||
Err(e) => Poll::Ready(Err(e)),
|
||||
Ok(()) => Poll::Ready(Err(gone()))
|
||||
};
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn time_out() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out")
|
||||
}
|
||||
|
||||
fn gone() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated")
|
||||
}
|
||||
|
||||
impl<F: Future, I: AsyncRead + AsyncWrite> AsyncRead for CancellableIo<F, I> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx))
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[io::IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs))
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.io().map(|io| io.is_write_vectored()).unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
use crate::http::private::{Listener, Connection, Certificates};
|
||||
|
||||
impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
|
||||
fn peer_address(&self) -> Option<std::net::SocketAddr> {
|
||||
self.io().and_then(|io| io.peer_address())
|
||||
}
|
||||
|
||||
fn peer_certificates(&self) -> Option<Certificates> {
|
||||
self.io().and_then(|io| io.peer_certificates())
|
||||
}
|
||||
|
||||
fn enable_nodelay(&self) -> io::Result<()> {
|
||||
match self.io() {
|
||||
Some(io) => io.enable_nodelay(),
|
||||
None => Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct CancellableListener<F, L> {
|
||||
pub trigger: F,
|
||||
#[pin]
|
||||
pub listener: L,
|
||||
pub grace: Duration,
|
||||
pub mercy: Duration,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, L> CancellableListener<F, L> {
|
||||
pub fn new(trigger: F, listener: L, grace: u64, mercy: u64) -> Self {
|
||||
let (grace, mercy) = (Duration::from_secs(grace), Duration::from_secs(mercy));
|
||||
CancellableListener { trigger, listener, grace, mercy }
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Listener, F: Future + Clone> Listener for CancellableListener<F, L> {
|
||||
type Connection = CancellableIo<F, L::Connection>;
|
||||
|
||||
fn local_addr(&self) -> Option<std::net::SocketAddr> {
|
||||
self.listener.local_addr()
|
||||
}
|
||||
|
||||
fn poll_accept(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<Self::Connection>> {
|
||||
self.as_mut().project().listener
|
||||
.poll_accept(cx)
|
||||
.map(|res| res.map(|conn| {
|
||||
CancellableIo::new(self.trigger.clone(), conn, self.grace, self.mercy)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait StreamExt: Sized + Stream {
|
||||
fn join<U>(self, other: U) -> Join<Self, U>
|
||||
where U: Stream<Item = Self::Item>;
|
||||
}
|
||||
|
||||
impl<S: Stream> StreamExt for S {
|
||||
fn join<U>(self, other: U) -> Join<Self, U>
|
||||
where U: Stream<Item = Self::Item>
|
||||
{
|
||||
Join::new(self, other)
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// Stream returned by the [`join`](super::StreamExt::join) method.
|
||||
pub struct Join<T, U> {
|
||||
#[pin]
|
||||
a: T,
|
||||
#[pin]
|
||||
b: U,
|
||||
// When `true`, poll `a` first, otherwise, `poll` b`.
|
||||
toggle: bool,
|
||||
// Set when either `a` or `b` return `None`.
|
||||
done: bool,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Join<T, U> {
|
||||
pub(super) fn new(a: T, b: U) -> Join<T, U>
|
||||
where T: Stream, U: Stream,
|
||||
{
|
||||
Join { a, b, toggle: false, done: false, }
|
||||
}
|
||||
|
||||
fn poll_next<A: Stream, B: Stream<Item = A::Item>>(
|
||||
first: Pin<&mut A>,
|
||||
second: Pin<&mut B>,
|
||||
done: &mut bool,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<A::Item>> {
|
||||
match first.poll_next(cx) {
|
||||
Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) }
|
||||
Poll::Pending => match second.poll_next(cx) {
|
||||
Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) }
|
||||
Poll::Pending => Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Stream for Join<T, U>
|
||||
where T: Stream,
|
||||
U: Stream<Item = T::Item>,
|
||||
{
|
||||
type Item = T::Item;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
|
||||
if self.done {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
let me = self.project();
|
||||
*me.toggle = !*me.toggle;
|
||||
match *me.toggle {
|
||||
true => Self::poll_next(me.a, me.b, me.done, cx),
|
||||
false => Self::poll_next(me.b, me.a, me.done, cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
let (left_low, left_high) = self.a.size_hint();
|
||||
let (right_low, right_high) = self.b.size_hint();
|
||||
|
||||
let low = left_low.saturating_add(right_low);
|
||||
let high = match (left_high, right_high) {
|
||||
(Some(h1), Some(h2)) => h1.checked_add(h2),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
(low, high)
|
||||
}
|
||||
}
|
|
@ -341,6 +341,7 @@
|
|||
// `key_contexts: Vec<K::Context>`, a vector of `value_contexts:
|
||||
// Vec<V::Context>`, a `mapping` from a string index to an integer index
|
||||
// into the `contexts`, and a vector of `errors`.
|
||||
//
|
||||
// 2. **Push.** An index is required; an error is emitted and `push` returns
|
||||
// if they field's first key does not contain an index. If the first key
|
||||
// contains _one_ index, a new `K::Context` and `V::Context` are created.
|
||||
|
@ -356,9 +357,9 @@
|
|||
// to `second` in `mapping`. If the first index is `k`, the field,
|
||||
// stripped of the first key, is pushed to the key's context; the same is
|
||||
// done for the value's context is the first index is `v`.
|
||||
//
|
||||
// 3. **Finalization.** Every context is finalized; errors and `Ok` values
|
||||
// are collected. TODO: FINISH. Split this into two: one for single-index,
|
||||
// another for two-indices.
|
||||
// are collected.
|
||||
|
||||
mod field;
|
||||
mod options;
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::io;
|
|||
use std::path::{Path, PathBuf};
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use tokio::fs::File;
|
||||
use tokio::fs::{File, OpenOptions};
|
||||
|
||||
use crate::request::Request;
|
||||
use crate::response::{self, Responder};
|
||||
|
@ -60,7 +60,7 @@ impl NamedFile {
|
|||
/// }
|
||||
/// ```
|
||||
pub async fn open<P: AsRef<Path>>(path: P) -> io::Result<NamedFile> {
|
||||
// FIXME: Grab the file size here and prohibit `seek`ing later (or else
|
||||
// TODO: Grab the file size here and prohibit `seek`ing later (or else
|
||||
// the file's effective size may change), to save on the cost of doing
|
||||
// all of those `seek`s to determine the file size. But, what happens if
|
||||
// the file gets changed between now and then?
|
||||
|
@ -68,6 +68,11 @@ impl NamedFile {
|
|||
Ok(NamedFile(path.as_ref().to_path_buf(), file))
|
||||
}
|
||||
|
||||
pub async fn open_with<P: AsRef<Path>>(path: P, opts: &OpenOptions) -> io::Result<NamedFile> {
|
||||
let file = opts.open(path.as_ref()).await?;
|
||||
Ok(NamedFile(path.as_ref().to_path_buf(), file))
|
||||
}
|
||||
|
||||
/// Retrieve the underlying `File`.
|
||||
///
|
||||
/// # Example
|
||||
|
|
|
@ -2,11 +2,10 @@ use std::fmt;
|
|||
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::http::private::cookie;
|
||||
use crate::{Rocket, Orbit};
|
||||
|
||||
#[doc(inline)]
|
||||
pub use self::cookie::{Cookie, SameSite, Iter};
|
||||
pub use cookie::{Cookie, SameSite, Iter};
|
||||
|
||||
/// Collection of one or more HTTP cookies.
|
||||
///
|
||||
|
@ -167,7 +166,7 @@ pub(crate) struct CookieState<'a> {
|
|||
#[derive(Clone)]
|
||||
enum Op {
|
||||
Add(Cookie<'static>, bool),
|
||||
Remove(Cookie<'static>, bool),
|
||||
Remove(Cookie<'static>),
|
||||
}
|
||||
|
||||
impl<'a> CookieJar<'a> {
|
||||
|
@ -177,7 +176,7 @@ impl<'a> CookieJar<'a> {
|
|||
ops: Mutex::new(Vec::new()),
|
||||
state: CookieState {
|
||||
// This is updated dynamically when headers are received.
|
||||
secure: rocket.config().tls_enabled(),
|
||||
secure: rocket.endpoint().is_tls(),
|
||||
config: rocket.config(),
|
||||
}
|
||||
}
|
||||
|
@ -256,7 +255,7 @@ impl<'a> CookieJar<'a> {
|
|||
for op in ops.iter().rev().filter(|op| op.cookie().name() == name) {
|
||||
match op {
|
||||
Op::Add(c, _) => return Some(c.clone()),
|
||||
Op::Remove(_, _) => return None,
|
||||
Op::Remove(_) => return None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -389,7 +388,7 @@ impl<'a> CookieJar<'a> {
|
|||
pub fn remove<C: Into<Cookie<'static>>>(&self, cookie: C) {
|
||||
let mut cookie = cookie.into();
|
||||
Self::set_removal_defaults(&mut cookie);
|
||||
self.ops.lock().push(Op::Remove(cookie, false));
|
||||
self.ops.lock().push(Op::Remove(cookie));
|
||||
}
|
||||
|
||||
/// Removes the private `cookie` from the collection.
|
||||
|
@ -432,7 +431,7 @@ impl<'a> CookieJar<'a> {
|
|||
pub fn remove_private<C: Into<Cookie<'static>>>(&self, cookie: C) {
|
||||
let mut cookie = cookie.into();
|
||||
Self::set_removal_defaults(&mut cookie);
|
||||
self.ops.lock().push(Op::Remove(cookie, true));
|
||||
self.ops.lock().push(Op::Remove(cookie));
|
||||
}
|
||||
|
||||
/// Returns an iterator over all of the _original_ cookies present in this
|
||||
|
@ -477,7 +476,7 @@ impl<'a> CookieJar<'a> {
|
|||
Op::Add(c, true) => {
|
||||
jar.private_mut(&self.state.config.secret_key.key).add(c);
|
||||
}
|
||||
Op::Remove(mut c, _) => {
|
||||
Op::Remove(mut c) => {
|
||||
if self.jar.get(c.name()).is_some() {
|
||||
c.make_removal();
|
||||
jar.add(c);
|
||||
|
@ -595,7 +594,7 @@ impl<'a> Clone for CookieJar<'a> {
|
|||
impl Op {
|
||||
fn cookie(&self) -> &Cookie<'static> {
|
||||
match self {
|
||||
Op::Add(c, _) | Op::Remove(c, _) => c
|
||||
Op::Add(c, _) | Op::Remove(c) => c
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
//! Types that map to concepts in HTTP.
|
||||
//!
|
||||
//! This module exports types that map to HTTP concepts or to the underlying
|
||||
//! HTTP library when needed.
|
||||
|
||||
mod cookies;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use rocket_http::*;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use cookies::*;
|
|
@ -7,7 +7,9 @@
|
|||
#![cfg_attr(nightly, feature(decl_macro))]
|
||||
|
||||
#![warn(rust_2018_idioms)]
|
||||
#![warn(missing_docs)]
|
||||
// #![warn(missing_docs)]
|
||||
#![allow(async_fn_in_trait)]
|
||||
#![allow(refining_impl_trait)]
|
||||
|
||||
//! # Rocket - Core API Documentation
|
||||
//!
|
||||
|
@ -109,18 +111,24 @@
|
|||
|
||||
/// These are public dependencies! Update docs if these are changed, especially
|
||||
/// figment's version number in docs.
|
||||
#[doc(hidden)] pub use yansi;
|
||||
#[doc(hidden)] pub use async_stream;
|
||||
#[doc(hidden)]
|
||||
pub use yansi;
|
||||
#[doc(hidden)]
|
||||
pub use async_stream;
|
||||
pub use futures;
|
||||
pub use tokio;
|
||||
pub use figment;
|
||||
pub use time;
|
||||
|
||||
#[doc(hidden)]
|
||||
#[macro_use] pub mod log;
|
||||
#[macro_use] pub mod outcome;
|
||||
#[macro_use] pub mod data;
|
||||
#[doc(hidden)] pub mod sentinel;
|
||||
#[macro_use]
|
||||
pub mod log;
|
||||
#[macro_use]
|
||||
pub mod outcome;
|
||||
#[macro_use]
|
||||
pub mod data;
|
||||
#[doc(hidden)]
|
||||
pub mod sentinel;
|
||||
pub mod local;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
|
@ -133,74 +141,41 @@ pub mod route;
|
|||
pub mod serde;
|
||||
pub mod shield;
|
||||
pub mod fs;
|
||||
|
||||
// Reexport of HTTP everything.
|
||||
pub mod http {
|
||||
//! Types that map to concepts in HTTP.
|
||||
//!
|
||||
//! This module exports types that map to HTTP concepts or to the underlying
|
||||
//! HTTP library when needed.
|
||||
|
||||
#[doc(inline)]
|
||||
pub use rocket_http::*;
|
||||
|
||||
/// Re-exported hyper HTTP library types.
|
||||
///
|
||||
/// All types that are re-exported from Hyper reside inside of this module.
|
||||
/// These types will, with certainty, be removed with time, but they reside here
|
||||
/// while necessary.
|
||||
pub mod hyper {
|
||||
#[doc(hidden)]
|
||||
pub use rocket_http::hyper::*;
|
||||
|
||||
pub use rocket_http::hyper::header;
|
||||
}
|
||||
|
||||
#[doc(inline)]
|
||||
pub use crate::cookies::*;
|
||||
}
|
||||
|
||||
pub mod http;
|
||||
pub mod listener;
|
||||
#[cfg(feature = "tls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
|
||||
pub mod tls;
|
||||
#[cfg(feature = "mtls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||
pub mod mtls;
|
||||
|
||||
/// TODO: We need a futures mod or something.
|
||||
mod trip_wire;
|
||||
mod util;
|
||||
mod shutdown;
|
||||
mod server;
|
||||
mod ext;
|
||||
mod lifecycle;
|
||||
mod state;
|
||||
mod cookies;
|
||||
mod rocket;
|
||||
mod router;
|
||||
mod phase;
|
||||
mod erased;
|
||||
|
||||
#[doc(hidden)] pub use either::Either;
|
||||
|
||||
#[doc(inline)] pub use rocket_codegen::*;
|
||||
|
||||
#[doc(inline)] pub use crate::response::Response;
|
||||
#[doc(inline)] pub use crate::data::Data;
|
||||
#[doc(inline)] pub use crate::config::Config;
|
||||
#[doc(inline)] pub use crate::catcher::Catcher;
|
||||
#[doc(inline)] pub use crate::route::Route;
|
||||
#[doc(hidden)] pub use either::Either;
|
||||
#[doc(inline)] pub use phase::{Phase, Build, Ignite, Orbit};
|
||||
#[doc(inline)] pub use error::Error;
|
||||
#[doc(inline)] pub use sentinel::Sentinel;
|
||||
#[doc(inline)] pub use crate::phase::{Phase, Build, Ignite, Orbit};
|
||||
#[doc(inline)] pub use crate::error::Error;
|
||||
#[doc(inline)] pub use crate::sentinel::Sentinel;
|
||||
#[doc(inline)] pub use crate::request::Request;
|
||||
#[doc(inline)] pub use crate::rocket::Rocket;
|
||||
#[doc(inline)] pub use crate::shutdown::Shutdown;
|
||||
#[doc(inline)] pub use crate::state::State;
|
||||
#[doc(inline)] pub use rocket_codegen::*;
|
||||
|
||||
/// Creates a [`Rocket`] instance with the default config provider: aliases
|
||||
/// [`Rocket::build()`].
|
||||
pub fn build() -> Rocket<Build> {
|
||||
Rocket::build()
|
||||
}
|
||||
|
||||
/// Creates a [`Rocket`] instance with a custom config provider: aliases
|
||||
/// [`Rocket::custom()`].
|
||||
pub fn custom<T: figment::Provider>(provider: T) -> Rocket<Build> {
|
||||
Rocket::custom(provider)
|
||||
}
|
||||
|
||||
/// Retrofits support for `async fn` in trait impls and declarations.
|
||||
///
|
||||
|
@ -231,6 +206,20 @@ pub fn custom<T: figment::Provider>(provider: T) -> Rocket<Build> {
|
|||
#[doc(inline)]
|
||||
pub use async_trait::async_trait;
|
||||
|
||||
const WORKER_PREFIX: &'static str = "rocket-worker";
|
||||
|
||||
/// Creates a [`Rocket`] instance with the default config provider: aliases
|
||||
/// [`Rocket::build()`].
|
||||
pub fn build() -> Rocket<Build> {
|
||||
Rocket::build()
|
||||
}
|
||||
|
||||
/// Creates a [`Rocket`] instance with a custom config provider: aliases
|
||||
/// [`Rocket::custom()`].
|
||||
pub fn custom<T: figment::Provider>(provider: T) -> Rocket<Build> {
|
||||
Rocket::custom(provider)
|
||||
}
|
||||
|
||||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
||||
#[doc(hidden)]
|
||||
pub fn async_run<F, R>(fut: F, workers: usize, sync: usize, force_end: bool, name: &str) -> R
|
||||
|
@ -255,7 +244,7 @@ pub fn async_run<F, R>(fut: F, workers: usize, sync: usize, force_end: bool, nam
|
|||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
||||
#[doc(hidden)]
|
||||
pub fn async_test<R>(fut: impl std::future::Future<Output = R>) -> R {
|
||||
async_run(fut, 1, 32, true, "rocket-worker-test-thread")
|
||||
async_run(fut, 1, 32, true, &format!("{WORKER_PREFIX}-test-thread"))
|
||||
}
|
||||
|
||||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
||||
|
@ -276,7 +265,7 @@ pub fn async_main<R>(fut: impl std::future::Future<Output = R> + Send) -> R {
|
|||
let workers = fig.extract_inner(Config::WORKERS).unwrap_or_else(bail);
|
||||
let max_blocking = fig.extract_inner(Config::MAX_BLOCKING).unwrap_or_else(bail);
|
||||
let force = fig.focus(Config::SHUTDOWN).extract_inner("force").unwrap_or_else(bail);
|
||||
async_run(fut, workers, max_blocking, force, "rocket-worker-thread")
|
||||
async_run(fut, workers, max_blocking, force, &format!("{WORKER_PREFIX}-thread"))
|
||||
}
|
||||
|
||||
/// Executes a `future` to completion on a new tokio-based Rocket async runtime.
|
||||
|
@ -359,3 +348,14 @@ pub fn execute<R, F>(future: F) -> R
|
|||
{
|
||||
async_main(future)
|
||||
}
|
||||
|
||||
/// Returns a future that evalutes to `true` exactly when there is a presently
|
||||
/// running tokio async runtime that was likely started by Rocket.
|
||||
fn running_within_rocket_async_rt() -> impl std::future::Future<Output = bool> {
|
||||
use futures::FutureExt;
|
||||
|
||||
tokio::task::spawn_blocking(|| {
|
||||
let this = std::thread::current();
|
||||
this.name().map_or(false, |s| s.starts_with(WORKER_PREFIX))
|
||||
}).map(|r| r.unwrap_or(false))
|
||||
}
|
||||
|
|
|
@ -0,0 +1,272 @@
|
|||
use yansi::Paint;
|
||||
use futures::future::{FutureExt, Future};
|
||||
|
||||
use crate::{route, Rocket, Orbit, Request, Response, Data};
|
||||
use crate::data::IoHandler;
|
||||
use crate::http::{Method, Status, Header};
|
||||
use crate::outcome::Outcome;
|
||||
use crate::form::Form;
|
||||
|
||||
// A token returned to force the execution of one method before another.
|
||||
pub(crate) struct RequestToken;
|
||||
|
||||
async fn catch_handle<Fut, T, F>(name: Option<&str>, run: F) -> Option<T>
|
||||
where F: FnOnce() -> Fut, Fut: Future<Output = T>,
|
||||
{
|
||||
macro_rules! panic_info {
|
||||
($name:expr, $e:expr) => {{
|
||||
match $name {
|
||||
Some(name) => error_!("Handler {} panicked.", name.primary()),
|
||||
None => error_!("A handler panicked.")
|
||||
};
|
||||
|
||||
info_!("This is an application bug.");
|
||||
info_!("A panic in Rust must be treated as an exceptional event.");
|
||||
info_!("Panicking is not a suitable error handling mechanism.");
|
||||
info_!("Unwinding, the result of a panic, is an expensive operation.");
|
||||
info_!("Panics will degrade application performance.");
|
||||
info_!("Instead of panicking, return `Option` and/or `Result`.");
|
||||
info_!("Values of either type can be returned directly from handlers.");
|
||||
warn_!("A panic is treated as an internal server error.");
|
||||
$e
|
||||
}}
|
||||
}
|
||||
|
||||
let run = std::panic::AssertUnwindSafe(run);
|
||||
let fut = std::panic::catch_unwind(move || run())
|
||||
.map_err(|e| panic_info!(name, e))
|
||||
.ok()?;
|
||||
|
||||
std::panic::AssertUnwindSafe(fut)
|
||||
.catch_unwind()
|
||||
.await
|
||||
.map_err(|e| panic_info!(name, e))
|
||||
.ok()
|
||||
}
|
||||
|
||||
impl Rocket<Orbit> {
|
||||
/// Preprocess the request for Rocket things. Currently, this means:
|
||||
///
|
||||
/// * Rewriting the method in the request if _method form field exists.
|
||||
/// * Run the request fairings.
|
||||
///
|
||||
/// This is the only place during lifecycle processing that `Request` is
|
||||
/// mutable. Keep this in-sync with the `FromForm` derive.
|
||||
pub(crate) async fn preprocess(
|
||||
&self,
|
||||
req: &mut Request<'_>,
|
||||
data: &mut Data<'_>
|
||||
) -> RequestToken {
|
||||
// Check if this is a form and if the form contains the special _method
|
||||
// field which we use to reinterpret the request's method.
|
||||
let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
|
||||
let peek_buffer = data.peek(max_len).await;
|
||||
let is_form = req.content_type().map_or(false, |ct| ct.is_form());
|
||||
|
||||
if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len {
|
||||
let method = std::str::from_utf8(peek_buffer).ok()
|
||||
.and_then(|raw_form| Form::values(raw_form).next())
|
||||
.filter(|field| field.name == "_method")
|
||||
.and_then(|field| field.value.parse().ok());
|
||||
|
||||
if let Some(method) = method {
|
||||
req.set_method(method);
|
||||
}
|
||||
}
|
||||
|
||||
// Run request fairings.
|
||||
self.fairings.handle_request(req, data).await;
|
||||
|
||||
RequestToken
|
||||
}
|
||||
|
||||
/// Dispatches the request to the router and processes the outcome to
|
||||
/// produce a response. If the initial outcome is a *forward* and the
|
||||
/// request was a HEAD request, the request is rewritten and rerouted as a
|
||||
/// GET. This is automatic HEAD handling.
|
||||
///
|
||||
/// After performing the above, if the outcome is a forward or error, the
|
||||
/// appropriate error catcher is invoked to produce the response. Otherwise,
|
||||
/// the successful response is used directly.
|
||||
///
|
||||
/// Finally, new cookies in the cookie jar are added to the response,
|
||||
/// Rocket-specific headers are written, and response fairings are run. Note
|
||||
/// that error responses have special cookie handling. See `handle_error`.
|
||||
pub(crate) async fn dispatch<'r, 's: 'r>(
|
||||
&'s self,
|
||||
_token: RequestToken,
|
||||
request: &'r Request<'s>,
|
||||
data: Data<'r>,
|
||||
// io_stream: impl Future<Output = io::Result<IoStream>> + Send,
|
||||
) -> Response<'r> {
|
||||
info!("{}:", request);
|
||||
|
||||
// Remember if the request is `HEAD` for later body stripping.
|
||||
let was_head_request = request.method() == Method::Head;
|
||||
|
||||
// Route the request and run the user's handlers.
|
||||
let mut response = match self.route(request, data).await {
|
||||
Outcome::Success(response) => response,
|
||||
Outcome::Forward((data, _)) if request.method() == Method::Head => {
|
||||
info_!("Autohandling {} request.", "HEAD".primary().bold());
|
||||
|
||||
// Dispatch the request again with Method `GET`.
|
||||
request._set_method(Method::Get);
|
||||
match self.route(request, data).await {
|
||||
Outcome::Success(response) => response,
|
||||
Outcome::Error(status) => self.dispatch_error(status, request).await,
|
||||
Outcome::Forward((_, status)) => self.dispatch_error(status, request).await,
|
||||
}
|
||||
}
|
||||
Outcome::Forward((_, status)) => self.dispatch_error(status, request).await,
|
||||
Outcome::Error(status) => self.dispatch_error(status, request).await,
|
||||
};
|
||||
|
||||
// Set the cookies. Note that error responses will only include cookies
|
||||
// set by the error handler. See `handle_error` for more.
|
||||
let delta_jar = request.cookies().take_delta_jar();
|
||||
for cookie in delta_jar.delta() {
|
||||
response.adjoin_header(cookie);
|
||||
}
|
||||
|
||||
// Add a default 'Server' header if it isn't already there.
|
||||
// TODO: If removing Hyper, write out `Date` header too.
|
||||
if let Some(ident) = request.rocket().config.ident.as_str() {
|
||||
if !response.headers().contains("Server") {
|
||||
response.set_header(Header::new("Server", ident));
|
||||
}
|
||||
}
|
||||
|
||||
// Run the response fairings.
|
||||
self.fairings.handle_response(request, &mut response).await;
|
||||
|
||||
// Strip the body if this is a `HEAD` request.
|
||||
if was_head_request {
|
||||
response.strip_body();
|
||||
}
|
||||
|
||||
// TODO: Should upgrades be handled here? We miss them on local clients.
|
||||
response
|
||||
}
|
||||
|
||||
pub(crate) fn extract_io_handler<'r>(
|
||||
request: &Request<'_>,
|
||||
response: &mut Response<'r>,
|
||||
// io_stream: impl Future<Output = io::Result<IoStream>> + Send,
|
||||
) -> Option<Box<dyn IoHandler + 'r>> {
|
||||
let upgrades = request.headers().get("upgrade");
|
||||
let Ok(upgrade) = response.search_upgrades(upgrades) else {
|
||||
warn_!("Request wants upgrade but no I/O handler matched.");
|
||||
info_!("Request is not being upgraded.");
|
||||
return None;
|
||||
};
|
||||
|
||||
if let Some((proto, io_handler)) = upgrade {
|
||||
info_!("Attemping upgrade with {proto} I/O handler.");
|
||||
response.set_status(Status::SwitchingProtocols);
|
||||
response.set_raw_header("Connection", "Upgrade");
|
||||
response.set_raw_header("Upgrade", proto.to_string());
|
||||
return Some(io_handler);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Calls the handler for each matching route until one of the handlers
|
||||
/// returns success or error, or there are no additional routes to try, in
|
||||
/// which case a `Forward` with the last forwarding state is returned.
|
||||
#[inline]
|
||||
async fn route<'s, 'r: 's>(
|
||||
&'s self,
|
||||
request: &'r Request<'s>,
|
||||
mut data: Data<'r>,
|
||||
) -> route::Outcome<'r> {
|
||||
// Go through all matching routes until we fail or succeed or run out of
|
||||
// routes to try, in which case we forward with the last status.
|
||||
let mut status = Status::NotFound;
|
||||
for route in self.router.route(request) {
|
||||
// Retrieve and set the requests parameters.
|
||||
info_!("Matched: {}", route);
|
||||
request.set_route(route);
|
||||
|
||||
let name = route.name.as_deref();
|
||||
let outcome = catch_handle(name, || route.handler.handle(request, data)).await
|
||||
.unwrap_or(Outcome::Error(Status::InternalServerError));
|
||||
|
||||
// Check if the request processing completed (Some) or if the
|
||||
// request needs to be forwarded. If it does, continue the loop
|
||||
// (None) to try again.
|
||||
info_!("{}", outcome.log_display());
|
||||
match outcome {
|
||||
o@Outcome::Success(_) | o@Outcome::Error(_) => return o,
|
||||
Outcome::Forward(forwarded) => (data, status) = forwarded,
|
||||
}
|
||||
}
|
||||
|
||||
error_!("No matching routes for {}.", request);
|
||||
Outcome::Forward((data, status))
|
||||
}
|
||||
|
||||
// Invokes the catcher for `status`. Returns the response on success.
|
||||
//
|
||||
// Resets the cookie jar delta state to prevent any modifications from
|
||||
// earlier unsuccessful paths from being reflected in the error response.
|
||||
//
|
||||
// On catcher error, the 500 error catcher is attempted. If _that_ errors,
|
||||
// the (infallible) default 500 error cather is used.
|
||||
pub(crate) async fn dispatch_error<'r, 's: 'r>(
|
||||
&'s self,
|
||||
mut status: Status,
|
||||
req: &'r Request<'s>
|
||||
) -> Response<'r> {
|
||||
// We may wish to relax this in the future.
|
||||
req.cookies().reset_delta();
|
||||
|
||||
// Dispatch to the `status` catcher.
|
||||
if let Ok(r) = self.invoke_catcher(status, req).await {
|
||||
return r;
|
||||
}
|
||||
|
||||
// If it fails and it's not a 500, try the 500 catcher.
|
||||
if status != Status::InternalServerError {
|
||||
error_!("Catcher failed. Attempting 500 error catcher.");
|
||||
status = Status::InternalServerError;
|
||||
if let Ok(r) = self.invoke_catcher(status, req).await {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
// If it failed again or if it was already a 500, use Rocket's default.
|
||||
error_!("{} catcher failed. Using Rocket default 500.", status.code);
|
||||
crate::catcher::default_handler(Status::InternalServerError, req)
|
||||
}
|
||||
|
||||
/// Invokes the handler with `req` for catcher with status `status`.
|
||||
///
|
||||
/// In order of preference, invoked handler is:
|
||||
/// * the user's registered handler for `status`
|
||||
/// * the user's registered `default` handler
|
||||
/// * Rocket's default handler for `status`
|
||||
///
|
||||
/// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))`
|
||||
/// if the handler ran to completion but failed. Returns `Ok(None)` if the
|
||||
/// handler panicked while executing.
|
||||
async fn invoke_catcher<'s, 'r: 's>(
|
||||
&'s self,
|
||||
status: Status,
|
||||
req: &'r Request<'s>
|
||||
) -> Result<Response<'r>, Option<Status>> {
|
||||
if let Some(catcher) = self.router.catch(status, req) {
|
||||
warn_!("Responding with registered {} catcher.", catcher);
|
||||
let name = catcher.name.as_deref();
|
||||
catch_handle(name, || catcher.handler.handle(status, req)).await
|
||||
.map(|result| result.map_err(Some))
|
||||
.unwrap_or_else(|| Err(None))
|
||||
} else {
|
||||
let code = status.code.blue().bold();
|
||||
warn_!("No {} catcher registered. Using Rocket default.", code);
|
||||
Ok(crate::catcher::default_handler(status, req))
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
use futures::TryFutureExt;
|
||||
|
||||
use crate::listener::Listener;
|
||||
|
||||
pub trait Bindable: Sized {
|
||||
type Listener: Listener + 'static;
|
||||
|
||||
type Error: std::error::Error + Send + 'static;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error>;
|
||||
}
|
||||
|
||||
impl<L: Listener + 'static> Bindable for L {
|
||||
type Listener = L;
|
||||
|
||||
type Error = std::convert::Infallible;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Bindable, B: Bindable> Bindable for either::Either<A, B> {
|
||||
type Listener = tokio_util::either::Either<A::Listener, B::Listener>;
|
||||
|
||||
type Error = either::Either<A::Error, B::Error>;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||
match self {
|
||||
either::Either::Left(a) => a.bind()
|
||||
.map_ok(tokio_util::either::Either::Left)
|
||||
.map_err(either::Either::Left)
|
||||
.await,
|
||||
either::Either::Right(b) => b.bind()
|
||||
.map_ok(tokio_util::either::Either::Right)
|
||||
.map_err(either::Either::Right)
|
||||
.await,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
use std::{io, time::Duration};
|
||||
|
||||
use crate::listener::{Listener, Endpoint};
|
||||
|
||||
static DURATION: Duration = Duration::from_millis(250);
|
||||
|
||||
pub struct Bounced<L> {
|
||||
listener: L,
|
||||
}
|
||||
|
||||
pub trait BouncedExt: Sized {
|
||||
fn bounced(self) -> Bounced<Self> {
|
||||
Bounced { listener: self }
|
||||
}
|
||||
}
|
||||
|
||||
impl<L> BouncedExt for L { }
|
||||
|
||||
fn is_recoverable(e: &io::Error) -> bool {
|
||||
matches!(e.kind(),
|
||||
| io::ErrorKind::ConnectionRefused
|
||||
| io::ErrorKind::ConnectionAborted
|
||||
| io::ErrorKind::ConnectionReset)
|
||||
}
|
||||
|
||||
impl<L: Listener + Sync> Bounced<L> {
|
||||
#[inline]
|
||||
pub async fn accept_next(&self) -> <Self as Listener>::Accept {
|
||||
loop {
|
||||
match self.listener.accept().await {
|
||||
Ok(accept) => return accept,
|
||||
Err(e) if is_recoverable(&e) => warn!("recoverable connection error: {e}"),
|
||||
Err(e) => {
|
||||
warn!("accept error: {e} [retrying in {}ms]", DURATION.as_millis());
|
||||
tokio::time::sleep(DURATION).await;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Listener + Sync> Listener for Bounced<L> {
|
||||
type Accept = L::Accept;
|
||||
|
||||
type Connection = L::Connection;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||
Ok(self.accept_next().await)
|
||||
}
|
||||
|
||||
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
|
||||
self.listener.connect(accept).await
|
||||
}
|
||||
|
||||
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||
self.listener.socket_addr()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,273 @@
|
|||
use std::io;
|
||||
use std::time::Duration;
|
||||
use std::task::{Poll, Context};
|
||||
use std::pin::Pin;
|
||||
|
||||
use tokio::time::{sleep, Sleep};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use futures::{StreamExt, future::{select, Either, Fuse, Future, FutureExt}};
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use crate::{config, Shutdown};
|
||||
use crate::listener::{Listener, Connection, Certificates, Bounced, Endpoint};
|
||||
|
||||
// Rocket wraps all connections in a `CancellableIo` struct, an internal
|
||||
// structure that gracefully closes I/O when it receives a signal. That signal
|
||||
// is the `shutdown` future. When the future resolves, `CancellableIo` begins to
|
||||
// terminate in grace, mercy, and finally force close phases. Since all
|
||||
// connections are wrapped in `CancellableIo`, this eventually ends all I/O.
|
||||
//
|
||||
// At that point, unless a user spawned an infinite, stand-alone task that isn't
|
||||
// monitoring `Shutdown`, all tasks should resolve. This means that all
|
||||
// instances of the shared `Arc<Rocket>` are dropped and we can return the owned
|
||||
// instance of `Rocket`.
|
||||
//
|
||||
// Unfortunately, the Hyper `server` future resolves as soon as it has finished
|
||||
// processing requests without respect for ongoing responses. That is, `server`
|
||||
// resolves even when there are running tasks that are generating a response.
|
||||
// So, `server` resolving implies little to nothing about the state of
|
||||
// connections. As a result, we depend on the timing of grace + mercy + some
|
||||
// buffer to determine when all connections should be closed, thus all tasks
|
||||
// should be complete, thus all references to `Arc<Rocket>` should be dropped
|
||||
// and we can get a unique reference.
|
||||
pin_project! {
|
||||
pub struct CancellableListener<F, L> {
|
||||
pub trigger: F,
|
||||
#[pin]
|
||||
pub listener: L,
|
||||
pub grace: Duration,
|
||||
pub mercy: Duration,
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// I/O that can be cancelled when a future `F` resolves.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct CancellableIo<F, I> {
|
||||
#[pin]
|
||||
io: Option<I>,
|
||||
#[pin]
|
||||
trigger: Fuse<F>,
|
||||
state: State,
|
||||
grace: Duration,
|
||||
mercy: Duration,
|
||||
}
|
||||
}
|
||||
|
||||
enum State {
|
||||
/// I/O has not been cancelled. Proceed as normal.
|
||||
Active,
|
||||
/// I/O has been cancelled. See if we can finish before the timer expires.
|
||||
Grace(Pin<Box<Sleep>>),
|
||||
/// Grace period elapsed. Shutdown the connection, waiting for the timer
|
||||
/// until we force close.
|
||||
Mercy(Pin<Box<Sleep>>),
|
||||
}
|
||||
|
||||
pub trait CancellableExt: Sized {
|
||||
fn cancellable(
|
||||
self,
|
||||
trigger: Shutdown,
|
||||
config: &config::Shutdown
|
||||
) -> CancellableListener<Shutdown, Self> {
|
||||
if let Some(mut stream) = config.signal_stream() {
|
||||
let trigger = trigger.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(sig) = stream.next().await {
|
||||
if trigger.0.tripped() {
|
||||
warn!("Received {}. Shutdown already in progress.", sig);
|
||||
} else {
|
||||
warn!("Received {}. Requesting shutdown.", sig);
|
||||
}
|
||||
|
||||
trigger.0.trip();
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
CancellableListener {
|
||||
trigger,
|
||||
listener: self,
|
||||
grace: config.grace(),
|
||||
mercy: config.mercy(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Listener> CancellableExt for L { }
|
||||
|
||||
fn time_out() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out")
|
||||
}
|
||||
|
||||
fn gone() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated")
|
||||
}
|
||||
|
||||
impl<L, F> CancellableListener<F, Bounced<L>>
|
||||
where L: Listener + Sync,
|
||||
F: Future + Unpin + Clone + Send + Sync + 'static
|
||||
{
|
||||
pub async fn accept_next(&self) -> Option<<Self as Listener>::Accept> {
|
||||
let next = std::pin::pin!(self.listener.accept_next());
|
||||
match select(next, self.trigger.clone()).await {
|
||||
Either::Left((next, _)) => Some(next),
|
||||
Either::Right(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<L, F> CancellableListener<F, L>
|
||||
where L: Listener + Sync,
|
||||
F: Future + Clone + Send + Sync + 'static
|
||||
{
|
||||
fn io<C>(&self, conn: C) -> CancellableIo<F, C> {
|
||||
CancellableIo {
|
||||
io: Some(conn),
|
||||
trigger: self.trigger.clone().fuse(),
|
||||
state: State::Active,
|
||||
grace: self.grace,
|
||||
mercy: self.mercy,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<L, F> Listener for CancellableListener<F, L>
|
||||
where L: Listener + Sync,
|
||||
F: Future + Clone + Send + Sync + Unpin + 'static
|
||||
{
|
||||
type Accept = L::Accept;
|
||||
|
||||
type Connection = CancellableIo<F, L::Connection>;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||
let accept = std::pin::pin!(self.listener.accept());
|
||||
match select(accept, self.trigger.clone()).await {
|
||||
Either::Left((result, _)) => result,
|
||||
Either::Right(_) => Err(gone()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
|
||||
let conn = std::pin::pin!(self.listener.connect(accept));
|
||||
match select(conn, self.trigger.clone()).await {
|
||||
Either::Left((conn, _)) => Ok(self.io(conn?)),
|
||||
Either::Right(_) => Err(gone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||
self.listener.socket_addr()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
|
||||
fn inner(&self) -> Option<&I> {
|
||||
self.io.as_ref()
|
||||
}
|
||||
|
||||
/// Run `do_io` while connection processing should continue.
|
||||
fn poll_trigger_then<T>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
|
||||
) -> Poll<io::Result<T>> {
|
||||
let mut me = self.as_mut().project();
|
||||
let io = match me.io.as_pin_mut() {
|
||||
Some(io) => io,
|
||||
None => return Poll::Ready(Err(gone())),
|
||||
};
|
||||
|
||||
loop {
|
||||
match me.state {
|
||||
State::Active => {
|
||||
if me.trigger.as_mut().poll(cx).is_ready() {
|
||||
*me.state = State::Grace(Box::pin(sleep(*me.grace)));
|
||||
} else {
|
||||
return do_io(io, cx);
|
||||
}
|
||||
}
|
||||
State::Grace(timer) => {
|
||||
if timer.as_mut().poll(cx).is_ready() {
|
||||
*me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
|
||||
} else {
|
||||
return do_io(io, cx);
|
||||
}
|
||||
}
|
||||
State::Mercy(timer) => {
|
||||
if timer.as_mut().poll(cx).is_ready() {
|
||||
self.project().io.set(None);
|
||||
return Poll::Ready(Err(time_out()));
|
||||
} else {
|
||||
let result = futures::ready!(io.poll_shutdown(cx));
|
||||
self.project().io.set(None);
|
||||
return match result {
|
||||
Err(e) => Poll::Ready(Err(e)),
|
||||
Ok(()) => Poll::Ready(Err(gone()))
|
||||
};
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future, I: AsyncRead + AsyncWrite> AsyncRead for CancellableIo<F, I> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx))
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[io::IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs))
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.inner().map(|io| io.is_write_vectored()).unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future, C: Connection> Connection for CancellableIo<F, C>
|
||||
where F: Unpin + Send + 'static
|
||||
{
|
||||
fn peer_address(&self) -> io::Result<Endpoint> {
|
||||
self.inner()
|
||||
.ok_or_else(|| gone())
|
||||
.and_then(|io| io.peer_address())
|
||||
}
|
||||
|
||||
fn peer_certificates(&self) -> Option<Certificates<'_>> {
|
||||
self.inner().and_then(|io| io.peer_certificates())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
use std::io;
|
||||
use std::borrow::Cow;
|
||||
|
||||
use tokio_util::either::Either;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use super::Endpoint;
|
||||
|
||||
/// A collection of raw certificate data.
|
||||
#[derive(Clone)]
|
||||
pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>);
|
||||
|
||||
pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
|
||||
fn peer_address(&self) -> io::Result<Endpoint>;
|
||||
|
||||
/// DER-encoded X.509 certificate chain presented by the client, if any.
|
||||
///
|
||||
/// The certificate order must be as it appears in the TLS protocol: the
|
||||
/// first certificate relates to the peer, the second certifies the first,
|
||||
/// the third certifies the second, and so on.
|
||||
///
|
||||
/// Defaults to an empty vector to indicate that no certificates were
|
||||
/// presented.
|
||||
fn peer_certificates(&self) -> Option<Certificates<'_>> { None }
|
||||
}
|
||||
|
||||
impl<A: Connection, B: Connection> Connection for Either<A, B> {
|
||||
fn peer_address(&self) -> io::Result<Endpoint> {
|
||||
match self {
|
||||
Either::Left(c) => c.peer_address(),
|
||||
Either::Right(c) => c.peer_address(),
|
||||
}
|
||||
}
|
||||
|
||||
fn peer_certificates(&self) -> Option<Certificates<'_>> {
|
||||
match self {
|
||||
Either::Left(c) => c.peer_certificates(),
|
||||
Either::Right(c) => c.peer_certificates(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Certificates<'_> {
|
||||
pub fn into_owned(self) -> Certificates<'static> {
|
||||
let cow = self.0.into_iter()
|
||||
.map(|der| der.clone().into_owned())
|
||||
.collect::<Vec<_>>()
|
||||
.into();
|
||||
|
||||
Certificates(cow)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||
mod der {
|
||||
use super::*;
|
||||
|
||||
pub use crate::mtls::CertificateDer;
|
||||
|
||||
impl<'r> Certificates<'r> {
|
||||
pub(crate) fn inner(&self) -> &[CertificateDer<'r>] {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> From<&'r [CertificateDer<'r>]> for Certificates<'r> {
|
||||
fn from(value: &'r [CertificateDer<'r>]) -> Self {
|
||||
Certificates(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<CertificateDer<'static>>> for Certificates<'static> {
|
||||
fn from(value: Vec<CertificateDer<'static>>) -> Self {
|
||||
Certificates(value.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "mtls"))]
|
||||
mod der {
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
|
||||
#[derive(Clone)]
|
||||
pub struct CertificateDer<'r>(PhantomData<&'r [u8]>);
|
||||
|
||||
impl CertificateDer<'_> {
|
||||
pub fn into_owned(self) -> CertificateDer<'static> {
|
||||
CertificateDer(PhantomData)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
use either::Either;
|
||||
|
||||
use crate::listener::{Bindable, Endpoint};
|
||||
use crate::error::{Error, ErrorKind};
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct DefaultListener {
|
||||
#[serde(default)]
|
||||
pub address: Endpoint,
|
||||
pub port: Option<u16>,
|
||||
pub reuse: Option<bool>,
|
||||
#[cfg(feature = "tls")]
|
||||
pub tls: Option<crate::tls::TlsConfig>,
|
||||
}
|
||||
|
||||
#[cfg(not(unix))] type BaseBindable = Either<std::net::SocketAddr, std::net::SocketAddr>;
|
||||
#[cfg(unix)] type BaseBindable = Either<std::net::SocketAddr, super::unix::UdsConfig>;
|
||||
|
||||
#[cfg(not(feature = "tls"))] type TlsBindable<T> = Either<T, T>;
|
||||
#[cfg(feature = "tls")] type TlsBindable<T> = Either<super::tls::TlsBindable<T>, T>;
|
||||
|
||||
impl DefaultListener {
|
||||
pub(crate) fn base_bindable(&self) -> Result<BaseBindable, crate::Error> {
|
||||
match &self.address {
|
||||
Endpoint::Tcp(mut address) => {
|
||||
self.port.map(|port| address.set_port(port));
|
||||
Ok(BaseBindable::Left(address))
|
||||
},
|
||||
#[cfg(unix)]
|
||||
Endpoint::Unix(path) => {
|
||||
let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, };
|
||||
Ok(BaseBindable::Right(uds))
|
||||
},
|
||||
#[cfg(not(unix))]
|
||||
Endpoint::Unix(_) => {
|
||||
let msg = "Unix domain sockets unavailable on non-unix platforms.";
|
||||
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
|
||||
Err(Error::new(ErrorKind::Bind(boxed)))
|
||||
},
|
||||
other => {
|
||||
let msg = format!("unsupported default listener address: {other}");
|
||||
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
|
||||
Err(Error::new(ErrorKind::Bind(boxed)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn tls_bindable<T>(&self, inner: T) -> TlsBindable<T> {
|
||||
#[cfg(feature = "tls")]
|
||||
if let Some(tls) = self.tls.clone() {
|
||||
return TlsBindable::Left(super::tls::TlsBindable { inner, tls });
|
||||
}
|
||||
|
||||
TlsBindable::Right(inner)
|
||||
}
|
||||
|
||||
pub fn bindable(&self) -> Result<impl Bindable, crate::Error> {
|
||||
self.base_bindable()
|
||||
.map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b)))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,281 @@
|
|||
use std::fmt;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::any::Any;
|
||||
use std::net::{SocketAddr as TcpAddr, Ipv4Addr, AddrParseError};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::de;
|
||||
|
||||
use crate::http::uncased::AsUncased;
|
||||
|
||||
pub trait EndpointAddr: fmt::Display + fmt::Debug + Sync + Send + Any { }
|
||||
|
||||
impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> EndpointAddr for T {}
|
||||
|
||||
#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>;
|
||||
#[cfg(feature = "tls")] type TlsInfo = Option<crate::tls::TlsConfig>;
|
||||
|
||||
/// # Conversions
|
||||
///
|
||||
/// * [`&str`] - parse with [`FromStr`]
|
||||
/// * [`tokio::net::unix::SocketAddr`] - must be path: [`ListenerAddr::Unix`]
|
||||
/// * [`std::net::SocketAddr`] - infallibly as [ListenerAddr::Tcp]
|
||||
/// * [`PathBuf`] - infallibly as [`ListenerAddr::Unix`]
|
||||
// TODO: Rename to something better. `Endpoint`?
|
||||
#[derive(Debug)]
|
||||
pub enum Endpoint {
|
||||
Tcp(TcpAddr),
|
||||
Unix(PathBuf),
|
||||
Tls(Arc<Endpoint>, TlsInfo),
|
||||
Custom(Arc<dyn EndpointAddr>),
|
||||
}
|
||||
|
||||
impl Endpoint {
|
||||
pub fn new<T: EndpointAddr>(value: T) -> Endpoint {
|
||||
Endpoint::Custom(Arc::new(value))
|
||||
}
|
||||
|
||||
pub fn tcp(&self) -> Option<TcpAddr> {
|
||||
match self {
|
||||
Endpoint::Tcp(addr) => Some(*addr),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unix(&self) -> Option<&Path> {
|
||||
match self {
|
||||
Endpoint::Unix(addr) => Some(addr),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tls(&self) -> Option<&Endpoint> {
|
||||
match self {
|
||||
Endpoint::Tls(addr, _) => Some(addr),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub fn tls_config(&self) -> Option<&crate::tls::TlsConfig> {
|
||||
match self {
|
||||
Endpoint::Tls(_, Some(ref config)) => Some(config),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
pub fn mtls_config(&self) -> Option<&crate::mtls::MtlsConfig> {
|
||||
match self {
|
||||
Endpoint::Tls(_, Some(config)) => config.mutual(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn downcast<T: 'static>(&self) -> Option<&T> {
|
||||
match self {
|
||||
Endpoint::Tcp(addr) => (&*addr as &dyn Any).downcast_ref(),
|
||||
Endpoint::Unix(addr) => (&*addr as &dyn Any).downcast_ref(),
|
||||
Endpoint::Custom(addr) => (&*addr as &dyn Any).downcast_ref(),
|
||||
Endpoint::Tls(inner, ..) => inner.downcast(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_tcp(&self) -> bool {
|
||||
self.tcp().is_some()
|
||||
}
|
||||
|
||||
pub fn is_unix(&self) -> bool {
|
||||
self.unix().is_some()
|
||||
}
|
||||
|
||||
pub fn is_tls(&self) -> bool {
|
||||
self.tls().is_some()
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub fn with_tls(self, config: crate::tls::TlsConfig) -> Endpoint {
|
||||
if self.is_tls() {
|
||||
return self;
|
||||
}
|
||||
|
||||
Self::Tls(Arc::new(self), Some(config))
|
||||
}
|
||||
|
||||
pub fn assume_tls(self) -> Endpoint {
|
||||
if self.is_tls() {
|
||||
return self;
|
||||
}
|
||||
|
||||
Self::Tls(Arc::new(self), None)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Endpoint {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
use Endpoint::*;
|
||||
|
||||
match self {
|
||||
Tcp(addr) => write!(f, "http://{addr}"),
|
||||
Unix(addr) => write!(f, "unix:{}", addr.display()),
|
||||
Custom(inner) => inner.fmt(f),
|
||||
Tls(inner, c) => match (&**inner, c.as_ref()) {
|
||||
#[cfg(feature = "mtls")]
|
||||
(Tcp(i), Some(c)) if c.mutual().is_some() => write!(f, "https://{i} (TLS + MTLS)"),
|
||||
(Tcp(i), _) => write!(f, "https://{i} (TLS)"),
|
||||
#[cfg(feature = "mtls")]
|
||||
(i, Some(c)) if c.mutual().is_some() => write!(f, "{i} (TLS + MTLS)"),
|
||||
(inner, _) => write!(f, "{inner} (TLS)"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::net::SocketAddr> for Endpoint {
|
||||
fn from(value: std::net::SocketAddr) -> Self {
|
||||
Self::Tcp(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::net::SocketAddrV4> for Endpoint {
|
||||
fn from(value: std::net::SocketAddrV4) -> Self {
|
||||
Self::Tcp(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::net::SocketAddrV6> for Endpoint {
|
||||
fn from(value: std::net::SocketAddrV6) -> Self {
|
||||
Self::Tcp(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PathBuf> for Endpoint {
|
||||
fn from(value: PathBuf) -> Self {
|
||||
Self::Unix(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl TryFrom<tokio::net::unix::SocketAddr> for Endpoint {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from(v: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
|
||||
v.as_pathname()
|
||||
.ok_or_else(|| std::io::Error::other("unix socket is not path"))
|
||||
.map(|path| Endpoint::Unix(path.to_path_buf()))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for Endpoint {
|
||||
type Error = AddrParseError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
value.parse()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Endpoint {
|
||||
fn default() -> Self {
|
||||
Endpoint::Tcp(TcpAddr::new(Ipv4Addr::LOCALHOST.into(), 8000))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses an address into a `ListenerAddr`.
|
||||
///
|
||||
/// The syntax is:
|
||||
///
|
||||
/// ```text
|
||||
/// listener_addr = 'tcp' ':' tcp_addr | 'unix' ':' unix_addr | tcp_addr
|
||||
/// tcp_addr := IP_ADDR | SOCKET_ADDR
|
||||
/// unix_addr := PATH
|
||||
///
|
||||
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
|
||||
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
|
||||
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
|
||||
/// ```
|
||||
///
|
||||
/// If `IP_ADDR` is specified, the port defaults to `8000`.
|
||||
impl FromStr for Endpoint {
|
||||
type Err = AddrParseError;
|
||||
|
||||
fn from_str(string: &str) -> Result<Self, Self::Err> {
|
||||
fn parse_tcp(string: &str, def_port: u16) -> Result<TcpAddr, AddrParseError> {
|
||||
string.parse().or_else(|_| string.parse().map(|ip| TcpAddr::new(ip, def_port)))
|
||||
}
|
||||
|
||||
if let Some((proto, string)) = string.split_once(':') {
|
||||
if proto.trim().as_uncased() == "tcp" {
|
||||
return parse_tcp(string.trim(), 8000).map(Self::Tcp);
|
||||
} else if proto.trim().as_uncased() == "unix" {
|
||||
return Ok(Self::Unix(PathBuf::from(string.trim())));
|
||||
}
|
||||
}
|
||||
|
||||
parse_tcp(string.trim(), 8000).map(Self::Tcp)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> de::Deserialize<'de> for Endpoint {
|
||||
fn deserialize<D: de::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
|
||||
struct Visitor;
|
||||
|
||||
impl<'de> de::Visitor<'de> for Visitor {
|
||||
type Value = Endpoint;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
formatter.write_str("TCP or Unix address")
|
||||
}
|
||||
|
||||
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
|
||||
v.parse::<Endpoint>().map_err(|e| E::custom(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
de.deserialize_any(Visitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for Endpoint { }
|
||||
|
||||
impl PartialEq for Endpoint {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(Self::Tcp(l0), Self::Tcp(r0)) => l0 == r0,
|
||||
(Self::Unix(l0), Self::Unix(r0)) => l0 == r0,
|
||||
(Self::Tls(l0, _), Self::Tls(r0, _)) => l0 == r0,
|
||||
(Self::Custom(l0), Self::Custom(r0)) => l0.to_string() == r0.to_string(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<std::net::SocketAddr> for Endpoint {
|
||||
fn eq(&self, other: &std::net::SocketAddr) -> bool {
|
||||
self.tcp() == Some(*other)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<std::net::SocketAddrV4> for Endpoint {
|
||||
fn eq(&self, other: &std::net::SocketAddrV4) -> bool {
|
||||
self.tcp() == Some((*other).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<std::net::SocketAddrV6> for Endpoint {
|
||||
fn eq(&self, other: &std::net::SocketAddrV6) -> bool {
|
||||
self.tcp() == Some((*other).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<PathBuf> for Endpoint {
|
||||
fn eq(&self, other: &PathBuf) -> bool {
|
||||
self.unix() == Some(other.as_path())
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<Path> for Endpoint {
|
||||
fn eq(&self, other: &Path) -> bool {
|
||||
self.unix() == Some(other)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
use std::io;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use tokio_util::either::Either;
|
||||
|
||||
use crate::listener::{Connection, Endpoint};
|
||||
|
||||
pub trait Listener: Send + Sync {
|
||||
type Accept: Send;
|
||||
|
||||
type Connection: Connection;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept>;
|
||||
|
||||
#[crate::async_bound(Send)]
|
||||
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection>;
|
||||
|
||||
fn socket_addr(&self) -> io::Result<Endpoint>;
|
||||
}
|
||||
|
||||
impl<L: Listener> Listener for &L {
|
||||
type Accept = L::Accept;
|
||||
|
||||
type Connection = L::Connection;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||
<L as Listener>::accept(self).await
|
||||
}
|
||||
|
||||
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
|
||||
<L as Listener>::connect(self, accept).await
|
||||
}
|
||||
|
||||
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||
<L as Listener>::socket_addr(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Listener, B: Listener> Listener for Either<A, B> {
|
||||
type Accept = Either<A::Accept, B::Accept>;
|
||||
|
||||
type Connection = Either<A::Connection, B::Connection>;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||
match self {
|
||||
Either::Left(l) => l.accept().map_ok(Either::Left).await,
|
||||
Either::Right(l) => l.accept().map_ok(Either::Right).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
|
||||
match (self, accept) {
|
||||
(Either::Left(l), Either::Left(a)) => l.connect(a).map_ok(Either::Left).await,
|
||||
(Either::Right(l), Either::Right(a)) => l.connect(a).map_ok(Either::Right).await,
|
||||
_ => unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||
match self {
|
||||
Either::Left(l) => l.socket_addr(),
|
||||
Either::Right(l) => l.socket_addr(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
mod cancellable;
|
||||
mod bounced;
|
||||
mod listener;
|
||||
mod endpoint;
|
||||
mod connection;
|
||||
mod bindable;
|
||||
mod default;
|
||||
|
||||
#[cfg(unix)]
|
||||
#[cfg_attr(nightly, doc(cfg(unix)))]
|
||||
pub mod unix;
|
||||
#[cfg(feature = "tls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
|
||||
pub mod tls;
|
||||
pub mod tcp;
|
||||
|
||||
pub use endpoint::*;
|
||||
pub use listener::*;
|
||||
pub use connection::*;
|
||||
pub use bindable::*;
|
||||
pub use default::*;
|
||||
|
||||
pub(crate) use cancellable::*;
|
||||
pub(crate) use bounced::*;
|
|
@ -0,0 +1,43 @@
|
|||
use std::io;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use crate::listener::{Listener, Bindable, Connection, Endpoint};
|
||||
|
||||
impl Bindable for std::net::SocketAddr {
|
||||
type Listener = TcpListener;
|
||||
|
||||
type Error = io::Error;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||
TcpListener::bind(self).await
|
||||
}
|
||||
}
|
||||
|
||||
impl Listener for TcpListener {
|
||||
type Accept = Self::Connection;
|
||||
|
||||
type Connection = TcpStream;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||
let conn = self.accept().await?.0;
|
||||
let _ = conn.set_nodelay(true);
|
||||
let _ = conn.set_linger(None);
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
async fn connect(&self, conn: Self::Connection) -> io::Result<Self::Connection> {
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||
self.local_addr().map(Endpoint::Tcp)
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for TcpStream {
|
||||
fn peer_address(&self) -> io::Result<Endpoint> {
|
||||
self.peer_addr().map(Endpoint::Tcp)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::Deserialize;
|
||||
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::tls::{TlsConfig, Error};
|
||||
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
|
||||
use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint};
|
||||
|
||||
#[doc(inline)]
|
||||
pub use tokio_rustls::server::TlsStream;
|
||||
|
||||
/// A TLS listener over some listener interface L.
|
||||
pub struct TlsListener<L> {
|
||||
listener: L,
|
||||
acceptor: TlsAcceptor,
|
||||
config: TlsConfig,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct TlsBindable<I> {
|
||||
#[serde(flatten)]
|
||||
pub inner: I,
|
||||
pub tls: TlsConfig,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
pub(crate) fn acceptor(&self) -> Result<tokio_rustls::TlsAcceptor, Error> {
|
||||
let provider = rustls::crypto::CryptoProvider {
|
||||
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
|
||||
..rustls::crypto::ring::default_provider()
|
||||
};
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
let verifier = match self.mutual {
|
||||
Some(ref mtls) => {
|
||||
let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?;
|
||||
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs));
|
||||
match mtls.mandatory {
|
||||
true => verifier.build()?,
|
||||
false => verifier.allow_unauthenticated().build()?,
|
||||
}
|
||||
},
|
||||
None => WebPkiClientVerifier::no_client_auth(),
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "mtls"))]
|
||||
let verifier = WebPkiClientVerifier::no_client_auth();
|
||||
|
||||
let key = load_key(&mut self.key_reader()?)?;
|
||||
let cert_chain = load_cert_chain(&mut self.certs_reader()?)?;
|
||||
let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider))
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_client_cert_verifier(verifier)
|
||||
.with_single_cert(cert_chain, key)?;
|
||||
|
||||
tls_config.ignore_client_order = self.prefer_server_cipher_order;
|
||||
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
|
||||
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
|
||||
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
if cfg!(feature = "http2") {
|
||||
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
|
||||
}
|
||||
|
||||
Ok(TlsAcceptor::from(Arc::new(tls_config)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Bindable> Bindable for TlsBindable<I> {
|
||||
type Listener = TlsListener<I::Listener>;
|
||||
|
||||
type Error = Error;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||
Ok(TlsListener {
|
||||
acceptor: self.tls.acceptor()?,
|
||||
listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?,
|
||||
config: self.tls,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Listener + Sync> Listener for TlsListener<L>
|
||||
where L::Connection: Unpin
|
||||
{
|
||||
type Accept = L::Accept;
|
||||
|
||||
type Connection = TlsStream<L::Connection>;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||
self.listener.accept().await
|
||||
}
|
||||
|
||||
async fn connect(&self, accept: L::Accept) -> io::Result<Self::Connection> {
|
||||
let conn = self.listener.connect(accept).await?;
|
||||
self.acceptor.accept(conn).await
|
||||
}
|
||||
|
||||
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||
Ok(self.listener.socket_addr()?.with_tls(self.config.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Connection + Unpin> Connection for TlsStream<C> {
|
||||
fn peer_address(&self) -> io::Result<Endpoint> {
|
||||
Ok(self.get_ref().0.peer_address()?.assume_tls())
|
||||
}
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
fn peer_certificates(&self) -> Option<Certificates<'_>> {
|
||||
let cert_chain = self.get_ref().1.peer_certificates()?;
|
||||
Some(Certificates::from(cert_chain))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,107 @@
|
|||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
use crate::fs::NamedFile;
|
||||
use crate::listener::{Listener, Bindable, Connection, Endpoint};
|
||||
use crate::util::unix;
|
||||
|
||||
pub use tokio::net::UnixStream;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UdsConfig {
|
||||
/// Socket address.
|
||||
pub path: PathBuf,
|
||||
/// Recreate a socket that already exists.
|
||||
pub reuse: Option<bool>,
|
||||
}
|
||||
|
||||
pub struct UdsListener {
|
||||
path: PathBuf,
|
||||
lock: Option<NamedFile>,
|
||||
listener: tokio::net::UnixListener,
|
||||
}
|
||||
|
||||
impl Bindable for UdsConfig {
|
||||
type Listener = UdsListener;
|
||||
|
||||
type Error = io::Error;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||
let lock = if self.reuse.unwrap_or(true) {
|
||||
let lock_ext = match self.path.extension().and_then(|s| s.to_str()) {
|
||||
Some(ext) if !ext.is_empty() => format!("{}.lock", ext),
|
||||
_ => "lock".to_string()
|
||||
};
|
||||
|
||||
let mut opts = tokio::fs::File::options();
|
||||
opts.create(true).write(true);
|
||||
let lock_path = self.path.with_extension(lock_ext);
|
||||
let lock_file = NamedFile::open_with(lock_path, &opts).await?;
|
||||
|
||||
unix::lock_exlusive_nonblocking(lock_file.file())?;
|
||||
if self.path.exists() {
|
||||
tokio::fs::remove_file(&self.path).await?;
|
||||
}
|
||||
|
||||
Some(lock_file)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Sometimes, we get `AddrInUse`, even though we've tried deleting the
|
||||
// socket. If all is well, eventually the socket will _really_ be gone,
|
||||
// and this will succeed. So let's try a few times.
|
||||
let mut retries = 5;
|
||||
let listener = loop {
|
||||
match tokio::net::UnixListener::bind(&self.path) {
|
||||
Ok(listener) => break listener,
|
||||
Err(e) if self.path.exists() && lock.is_none() => return Err(e),
|
||||
Err(_) if retries > 0 => {
|
||||
retries -= 1;
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
},
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
};
|
||||
|
||||
Ok(UdsListener { lock, listener, path: self.path, })
|
||||
}
|
||||
}
|
||||
|
||||
impl Listener for UdsListener {
|
||||
type Accept = UnixStream;
|
||||
|
||||
type Connection = Self::Accept;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||
Ok(self.listener.accept().await?.0)
|
||||
}
|
||||
|
||||
async fn connect(&self, accept:Self::Accept) -> io::Result<Self::Connection> {
|
||||
Ok(accept)
|
||||
}
|
||||
|
||||
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||
self.listener.local_addr()?.try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for UnixStream {
|
||||
fn peer_address(&self) -> io::Result<Endpoint> {
|
||||
self.local_addr()?.try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for UdsListener {
|
||||
fn drop(&mut self) {
|
||||
if let Some(lock) = &self.lock {
|
||||
let _ = std::fs::remove_file(&self.path);
|
||||
let _ = std::fs::remove_file(lock.path());
|
||||
let _ = unix::unlock_nonblocking(lock.file());
|
||||
} else {
|
||||
let _ = std::fs::remove_file(&self.path);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,7 +4,8 @@ use parking_lot::RwLock;
|
|||
|
||||
use crate::{Rocket, Phase, Orbit, Ignite, Error};
|
||||
use crate::local::asynchronous::{LocalRequest, LocalResponse};
|
||||
use crate::http::{Method, uri::Origin, private::cookie};
|
||||
use crate::http::{Method, uri::Origin};
|
||||
use crate::listener::Endpoint;
|
||||
|
||||
/// An `async` client to construct and dispatch local requests.
|
||||
///
|
||||
|
@ -55,9 +56,15 @@ pub struct Client {
|
|||
impl Client {
|
||||
pub(crate) async fn _new<P: Phase>(
|
||||
rocket: Rocket<P>,
|
||||
tracked: bool
|
||||
tracked: bool,
|
||||
secure: bool,
|
||||
) -> Result<Client, Error> {
|
||||
let rocket = rocket.local_launch().await?;
|
||||
let mut listener = Endpoint::new("local client");
|
||||
if secure {
|
||||
listener = listener.assume_tls();
|
||||
}
|
||||
|
||||
let rocket = rocket.local_launch(listener).await?;
|
||||
let cookies = RwLock::new(cookie::CookieJar::new());
|
||||
Ok(Client { rocket, cookies, tracked })
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ use super::{Client, LocalResponse};
|
|||
/// let client = Client::tracked(rocket::build()).await.expect("valid rocket");
|
||||
/// let req = client.post("/")
|
||||
/// .header(ContentType::JSON)
|
||||
/// .remote("127.0.0.1:8000".parse().unwrap())
|
||||
/// .remote("127.0.0.1:8000")
|
||||
/// .cookie(("name", "value"))
|
||||
/// .body(r#"{ "value": 42 }"#);
|
||||
///
|
||||
|
@ -86,14 +86,14 @@ impl<'c> LocalRequest<'c> {
|
|||
if self.inner().uri() == invalid {
|
||||
error!("invalid request URI: {:?}", invalid.path());
|
||||
return LocalResponse::new(self.request, move |req| {
|
||||
rocket.handle_error(Status::BadRequest, req)
|
||||
rocket.dispatch_error(Status::BadRequest, req)
|
||||
}).await
|
||||
}
|
||||
}
|
||||
|
||||
// Actually dispatch the request.
|
||||
let mut data = Data::local(self.data);
|
||||
let token = rocket.preprocess_request(&mut self.request, &mut data).await;
|
||||
let token = rocket.preprocess(&mut self.request, &mut data).await;
|
||||
let response = LocalResponse::new(self.request, move |req| {
|
||||
rocket.dispatch(token, req, data)
|
||||
}).await;
|
||||
|
|
|
@ -53,9 +53,14 @@ use crate::{Request, Response};
|
|||
///
|
||||
/// For more, see [the top-level documentation](../index.html#localresponse).
|
||||
pub struct LocalResponse<'c> {
|
||||
_request: Box<Request<'c>>,
|
||||
// XXX: SAFETY: This (dependent) field must come first due to drop order!
|
||||
response: Response<'c>,
|
||||
cookies: CookieJar<'c>,
|
||||
_request: Box<Request<'c>>,
|
||||
}
|
||||
|
||||
impl Drop for LocalResponse<'_> {
|
||||
fn drop(&mut self) { }
|
||||
}
|
||||
|
||||
impl<'c> LocalResponse<'c> {
|
||||
|
@ -64,7 +69,8 @@ impl<'c> LocalResponse<'c> {
|
|||
O: Future<Output = Response<'c>> + Send
|
||||
{
|
||||
// `LocalResponse` is a self-referential structure. In particular,
|
||||
// `inner` can refer to `_request` and its contents. As such, we must
|
||||
// `response` and `cookies` can refer to `_request` and its contents. As
|
||||
// such, we must
|
||||
// 1) Ensure `Request` has a stable address.
|
||||
//
|
||||
// This is done by `Box`ing the `Request`, using only the stable
|
||||
|
@ -97,7 +103,7 @@ impl<'c> LocalResponse<'c> {
|
|||
cookies.add_original(cookie.into_owned());
|
||||
}
|
||||
|
||||
LocalResponse { cookies, _request: boxed_req, response, }
|
||||
LocalResponse { _request: boxed_req, cookies, response, }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ pub struct Client {
|
|||
}
|
||||
|
||||
impl Client {
|
||||
fn _new<P: Phase>(rocket: Rocket<P>, tracked: bool) -> Result<Client, Error> {
|
||||
fn _new<P: Phase>(rocket: Rocket<P>, tracked: bool, secure: bool) -> Result<Client, Error> {
|
||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||
.thread_name("rocket-local-client-worker-thread")
|
||||
.worker_threads(1)
|
||||
|
@ -39,7 +39,7 @@ impl Client {
|
|||
.expect("create tokio runtime");
|
||||
|
||||
// Initialize the Rocket instance
|
||||
let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked))?);
|
||||
let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked, secure))?);
|
||||
Ok(Self { inner, runtime: RefCell::new(runtime) })
|
||||
}
|
||||
|
||||
|
@ -73,7 +73,7 @@ impl Client {
|
|||
|
||||
#[inline(always)]
|
||||
pub(crate) fn _with_raw_cookies<F, T>(&self, f: F) -> T
|
||||
where F: FnOnce(&crate::http::private::cookie::CookieJar) -> T
|
||||
where F: FnOnce(&cookie::CookieJar) -> T
|
||||
{
|
||||
self.inner()._with_raw_cookies(f)
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ use super::{Client, LocalResponse};
|
|||
/// let client = Client::tracked(rocket::build()).expect("valid rocket");
|
||||
/// let req = client.post("/")
|
||||
/// .header(ContentType::JSON)
|
||||
/// .remote("127.0.0.1:8000".parse().unwrap())
|
||||
/// .remote("127.0.0.1:8000")
|
||||
/// .cookie(("name", "value"))
|
||||
/// .body(r#"{ "value": 42 }"#);
|
||||
///
|
||||
|
|
|
@ -68,7 +68,12 @@ macro_rules! pub_client_impl {
|
|||
/// ```
|
||||
#[inline(always)]
|
||||
pub $($prefix)? fn tracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
||||
Self::_new(rocket, true) $(.$suffix)?
|
||||
Self::_new(rocket, true, false) $(.$suffix)?
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub $($prefix)? fn tracked_secure<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
||||
Self::_new(rocket, true, true) $(.$suffix)?
|
||||
}
|
||||
|
||||
/// Construct a new `Client` from an instance of `Rocket` _without_
|
||||
|
@ -92,7 +97,11 @@ macro_rules! pub_client_impl {
|
|||
/// let client = Client::untracked(rocket);
|
||||
/// ```
|
||||
pub $($prefix)? fn untracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
||||
Self::_new(rocket, false) $(.$suffix)?
|
||||
Self::_new(rocket, false, false) $(.$suffix)?
|
||||
}
|
||||
|
||||
pub $($prefix)? fn untracked_secure<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
||||
Self::_new(rocket, false, true) $(.$suffix)?
|
||||
}
|
||||
|
||||
/// Terminates `Client` by initiating a graceful shutdown via
|
||||
|
@ -135,15 +144,6 @@ macro_rules! pub_client_impl {
|
|||
Self::tracked(rocket.configure(figment)) $(.$suffix)?
|
||||
}
|
||||
|
||||
/// Deprecated alias to [`Client::tracked()`].
|
||||
#[deprecated(
|
||||
since = "0.6.0-dev",
|
||||
note = "choose between `Client::untracked()` and `Client::tracked()`"
|
||||
)]
|
||||
pub $($prefix)? fn new<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
||||
Self::tracked(rocket) $(.$suffix)?
|
||||
}
|
||||
|
||||
/// Returns a reference to the `Rocket` this client is creating requests
|
||||
/// for.
|
||||
///
|
||||
|
|
|
@ -97,24 +97,40 @@ macro_rules! pub_request_impl {
|
|||
self._request_mut().add_header(header.into());
|
||||
}
|
||||
|
||||
/// Set the remote address of this request.
|
||||
/// Set the remote address of this request to `address`.
|
||||
///
|
||||
/// `address` may be any type that [can be converted into a `ListenerAddr`].
|
||||
/// If `address` fails to convert, the remote is left unchanged.
|
||||
///
|
||||
/// [can be converted into a `ListenerAddr`]: crate::listener::ListenerAddr#conversions
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// Set the remote address to "8.8.8.8:80":
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::net::{SocketAddrV4, Ipv4Addr};
|
||||
///
|
||||
#[doc = $import]
|
||||
///
|
||||
/// # Client::_test(|_, request, _| {
|
||||
/// let request: LocalRequest = request;
|
||||
/// let address = "8.8.8.8:80".parse().unwrap();
|
||||
/// let req = request.remote(address);
|
||||
/// let req = request.remote("8.8.8.8:80");
|
||||
///
|
||||
/// let addr = SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8).into(), 80);
|
||||
/// assert_eq!(req.inner().remote().unwrap(), &addr);
|
||||
/// # });
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn remote(mut self, address: std::net::SocketAddr) -> Self {
|
||||
self.set_remote(address);
|
||||
pub fn remote<T>(mut self, endpoint: T) -> Self
|
||||
where T: TryInto<crate::listener::Endpoint>
|
||||
{
|
||||
if let Ok(endpoint) = endpoint.try_into() {
|
||||
self.set_remote(endpoint);
|
||||
} else {
|
||||
warn!("remote failed to convert");
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -228,11 +244,13 @@ macro_rules! pub_request_impl {
|
|||
#[cfg(feature = "mtls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||
pub fn identity<C: std::io::Read>(mut self, reader: C) -> Self {
|
||||
use crate::http::{tls::util::load_cert_chain, private::Certificates};
|
||||
use std::sync::Arc;
|
||||
use crate::tls::util::load_cert_chain;
|
||||
use crate::listener::Certificates;
|
||||
|
||||
let mut reader = std::io::BufReader::new(reader);
|
||||
let certs = load_cert_chain(&mut reader).map(Certificates::from);
|
||||
self._request_mut().connection.client_certificates = certs.ok();
|
||||
self._request_mut().connection.peer_certs = certs.ok().map(Arc::new);
|
||||
self
|
||||
}
|
||||
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
//! Support for mutual TLS client certificates.
|
||||
//!
|
||||
//! For details on how to configure mutual TLS, see
|
||||
//! [`MutualTls`](crate::config::MutualTls) and the [TLS
|
||||
//! guide](https://rocket.rs/master/guide/configuration/#tls). See
|
||||
//! [`Certificate`] for a request guard that validated, verifies, and retrieves
|
||||
//! client certificates.
|
||||
|
||||
#[doc(inline)]
|
||||
pub use crate::http::tls::mtls::*;
|
||||
|
||||
use crate::request::{Request, FromRequest, Outcome};
|
||||
use crate::outcome::{try_outcome, IntoOutcome};
|
||||
use crate::http::Status;
|
||||
|
||||
#[crate::async_trait]
|
||||
impl<'r> FromRequest<'r> for Certificate<'r> {
|
||||
type Error = Error;
|
||||
|
||||
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
let certs = req.connection.client_certificates.as_ref().or_forward(Status::Unauthorized);
|
||||
let data = try_outcome!(try_outcome!(certs).chain_data().or_forward(Status::Unauthorized));
|
||||
Certificate::parse(data).or_error(Status::Unauthorized)
|
||||
}
|
||||
}
|
|
@ -1,51 +1,8 @@
|
|||
pub mod oid {
|
||||
//! Lower-level OID types re-exported from
|
||||
//! [`oid_registry`](https://docs.rs/oid-registry/0.4) and
|
||||
//! [`der-parser`](https://docs.rs/der-parser/7).
|
||||
|
||||
pub use x509_parser::oid_registry::*;
|
||||
pub use x509_parser::objects::*;
|
||||
}
|
||||
|
||||
pub mod bigint {
|
||||
//! Signed and unsigned big integer types re-exported from
|
||||
//! [`num_bigint`](https://docs.rs/num-bigint/0.4).
|
||||
pub use x509_parser::der_parser::num_bigint::*;
|
||||
}
|
||||
|
||||
pub mod x509 {
|
||||
//! Lower-level X.509 types re-exported from
|
||||
//! [`x509_parser`](https://docs.rs/x509-parser/0.13).
|
||||
//!
|
||||
//! Lack of documentation is directly inherited from the source crate.
|
||||
//! Prefer to use Rocket's wrappers when possible.
|
||||
|
||||
pub use x509_parser::certificate::*;
|
||||
pub use x509_parser::cri_attributes::*;
|
||||
pub use x509_parser::error::*;
|
||||
pub use x509_parser::extensions::*;
|
||||
pub use x509_parser::revocation_list::*;
|
||||
pub use x509_parser::time::*;
|
||||
pub use x509_parser::x509::*;
|
||||
pub use x509_parser::der_parser::der;
|
||||
pub use x509_parser::der_parser::ber;
|
||||
pub use x509_parser::traits::*;
|
||||
}
|
||||
|
||||
use std::fmt;
|
||||
use std::ops::Deref;
|
||||
use std::num::NonZeroUsize;
|
||||
|
||||
use ref_cast::RefCast;
|
||||
use x509_parser::nom;
|
||||
use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error, FromDer};
|
||||
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;
|
||||
|
||||
use crate::listener::CertificateDer;
|
||||
|
||||
/// A type alias for [`Result`](std::result::Result) with the error type set to
|
||||
/// [`Error`].
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
use crate::mtls::{x509, oid, bigint, Name, Result, Error};
|
||||
use crate::request::{Request, FromRequest, Outcome};
|
||||
use crate::http::Status;
|
||||
|
||||
/// A request guard for validated, verified client certificates.
|
||||
///
|
||||
|
@ -143,60 +100,42 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
|
|||
/// ```
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct Certificate<'a> {
|
||||
x509: X509Certificate<'a>,
|
||||
data: &'a CertificateDer,
|
||||
x509: x509::X509Certificate<'a>,
|
||||
data: &'a CertificateDer<'a>,
|
||||
}
|
||||
|
||||
/// An X.509 Distinguished Name (DN) found in a [`Certificate`].
|
||||
///
|
||||
/// This type is a wrapper over [`x509::X509Name`] with convenient methods and
|
||||
/// complete documentation. Should the data exposed by the inherent methods not
|
||||
/// suffice, this type derefs to [`x509::X509Name`].
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, PartialEq, RefCast)]
|
||||
pub struct Name<'a>(X509Name<'a>);
|
||||
pub use rustls::pki_types::CertificateDer;
|
||||
|
||||
/// An error returned by the [`Certificate`] request guard.
|
||||
///
|
||||
/// To retrieve this error in a handler, use an `mtls::Result<Certificate>`
|
||||
/// guard type:
|
||||
///
|
||||
/// ```rust
|
||||
/// # extern crate rocket;
|
||||
/// # use rocket::get;
|
||||
/// use rocket::mtls::{self, Certificate};
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: mtls::Result<Certificate<'_>>) {
|
||||
/// match cert {
|
||||
/// Ok(cert) => { /* do something with the client cert */ },
|
||||
/// Err(e) => { /* do something with the error */ },
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub enum Error {
|
||||
/// The certificate chain presented by the client had no certificates.
|
||||
Empty,
|
||||
/// The certificate contained neither a subject nor a subjectAlt extension.
|
||||
NoSubject,
|
||||
/// There is no subject and the subjectAlt is not marked as critical.
|
||||
NonCriticalSubjectAlt,
|
||||
/// An error occurred while parsing the certificate.
|
||||
Parse(X509Error),
|
||||
/// The certificate parsed partially but is incomplete.
|
||||
///
|
||||
/// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number
|
||||
/// of expected bytes is unknown.
|
||||
Incomplete(Option<NonZeroUsize>),
|
||||
/// The certificate contained `.0` bytes of trailing data.
|
||||
Trailing(usize),
|
||||
#[crate::async_trait]
|
||||
impl<'r> FromRequest<'r> for Certificate<'r> {
|
||||
type Error = Error;
|
||||
|
||||
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
use crate::outcome::{try_outcome, IntoOutcome};
|
||||
|
||||
let certs = req.connection
|
||||
.peer_certs
|
||||
.as_ref()
|
||||
.or_forward(Status::Unauthorized);
|
||||
|
||||
let chain = try_outcome!(certs);
|
||||
Certificate::parse(chain.inner()).or_error(Status::Unauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Certificate<'a> {
|
||||
fn parse_one(raw: &[u8]) -> Result<X509Certificate<'_>> {
|
||||
let (left, x509) = X509Certificate::from_der(raw)?;
|
||||
/// PRIVATE: For internal Rocket use only!
|
||||
fn parse<'r>(chain: &'r [CertificateDer<'r>]) -> Result<Certificate<'r>> {
|
||||
let data = chain.first().ok_or_else(|| Error::Empty)?;
|
||||
let x509 = Certificate::parse_one(&*data)?;
|
||||
Ok(Certificate { x509, data })
|
||||
}
|
||||
|
||||
fn parse_one(raw: &[u8]) -> Result<x509::X509Certificate<'_>> {
|
||||
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;
|
||||
use x509_parser::traits::FromDer;
|
||||
|
||||
let (left, x509) = x509::X509Certificate::from_der(raw)?;
|
||||
if !left.is_empty() {
|
||||
return Err(Error::Trailing(left.len()));
|
||||
}
|
||||
|
@ -204,7 +143,7 @@ impl<'a> Certificate<'a> {
|
|||
// Ensure we have a subject or a subjectAlt.
|
||||
if x509.subject().as_raw().is_empty() {
|
||||
if let Some(ext) = x509.extensions().iter().find(|e| e.oid == SUBJECT_ALT_NAME) {
|
||||
if !matches!(ext.parsed_extension(), ParsedExtension::SubjectAlternativeName(..)) {
|
||||
if let x509::ParsedExtension::SubjectAlternativeName(..) = ext.parsed_extension() {
|
||||
return Err(Error::NoSubject);
|
||||
} else if !ext.critical {
|
||||
return Err(Error::NonCriticalSubjectAlt);
|
||||
|
@ -218,18 +157,10 @@ impl<'a> Certificate<'a> {
|
|||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn inner(&self) -> &TbsCertificate<'a> {
|
||||
fn inner(&self) -> &x509::TbsCertificate<'a> {
|
||||
&self.x509.tbs_certificate
|
||||
}
|
||||
|
||||
/// PRIVATE: For internal Rocket use only!
|
||||
#[doc(hidden)]
|
||||
pub fn parse(chain: &[CertificateDer]) -> Result<Certificate<'_>> {
|
||||
let data = chain.first().ok_or_else(|| Error::Empty)?;
|
||||
let x509 = Certificate::parse_one(&data.0)?;
|
||||
Ok(Certificate { x509, data })
|
||||
}
|
||||
|
||||
/// Returns the serial number of the X.509 certificate.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -387,176 +318,14 @@ impl<'a> Certificate<'a> {
|
|||
/// }
|
||||
/// ```
|
||||
pub fn as_bytes(&self) -> &'a [u8] {
|
||||
&self.data.0
|
||||
&*self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Deref for Certificate<'a> {
|
||||
type Target = TbsCertificate<'a>;
|
||||
impl<'a> std::ops::Deref for Certificate<'a> {
|
||||
type Target = x509::TbsCertificate<'a>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.inner()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Name<'a> {
|
||||
/// Returns the _first_ UTF-8 _string_ common name, if any.
|
||||
///
|
||||
/// Note that common names need not be UTF-8 strings, or strings at all.
|
||||
/// This method returns the first common name attribute that is.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::Certificate;
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: Certificate<'_>) {
|
||||
/// if let Some(name) = cert.subject().common_name() {
|
||||
/// println!("Hello, {}!", name);
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn common_name(&self) -> Option<&'a str> {
|
||||
self.common_names().next()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all of the UTF-8 _string_ common names in
|
||||
/// `self`.
|
||||
///
|
||||
/// Note that common names need not be UTF-8 strings, or strings at all.
|
||||
/// This method filters the common names in `self` to those that are. Use
|
||||
/// the raw [`iter_common_name()`](#method.iter_common_name) to iterate over
|
||||
/// all value types.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::Certificate;
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: Certificate<'_>) {
|
||||
/// for name in cert.issuer().common_names() {
|
||||
/// println!("Issued by {}.", name);
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn common_names(&self) -> impl Iterator<Item = &'a str> + '_ {
|
||||
self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok())
|
||||
}
|
||||
|
||||
/// Returns the _first_ UTF-8 _string_ email address, if any.
|
||||
///
|
||||
/// Note that email addresses need not be UTF-8 strings, or strings at all.
|
||||
/// This method returns the first email address attribute that is.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::Certificate;
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: Certificate<'_>) {
|
||||
/// if let Some(email) = cert.subject().email() {
|
||||
/// println!("Hello, {}!", email);
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn email(&self) -> Option<&'a str> {
|
||||
self.emails().next()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all of the UTF-8 _string_ email addresses in
|
||||
/// `self`.
|
||||
///
|
||||
/// Note that email addresses need not be UTF-8 strings, or strings at all.
|
||||
/// This method filters the email address in `self` to those that are. Use
|
||||
/// the raw [`iter_email()`](#method.iter_email) to iterate over all value
|
||||
/// types.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::Certificate;
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: Certificate<'_>) {
|
||||
/// for email in cert.subject().emails() {
|
||||
/// println!("Reach me at: {}", email);
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn emails(&self) -> impl Iterator<Item = &'a str> + '_ {
|
||||
self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok())
|
||||
}
|
||||
|
||||
/// Returns `true` if `self` has no data.
|
||||
///
|
||||
/// When this is the case for a `subject()`, the subject data can be found
|
||||
/// in the `subjectAlt` [`extension()`](Certificate::extensions()).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::Certificate;
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: Certificate<'_>) {
|
||||
/// let no_data = cert.subject().is_empty();
|
||||
/// }
|
||||
/// ```
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.as_raw().is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Deref for Name<'a> {
|
||||
type Target = X509Name<'a>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Name<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Error::Parse(e) => write!(f, "parse error: {}", e),
|
||||
Error::Incomplete(_) => write!(f, "incomplete certificate data"),
|
||||
Error::Trailing(n) => write!(f, "found {} trailing bytes", n),
|
||||
Error::Empty => write!(f, "empty certificate chain"),
|
||||
Error::NoSubject => write!(f, "empty subject without subjectAlt"),
|
||||
Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<nom::Err<X509Error>> for Error {
|
||||
fn from(e: nom::Err<X509Error>) -> Self {
|
||||
match e {
|
||||
nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None),
|
||||
nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)),
|
||||
nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
// fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
// match self {
|
||||
// Error::Parse(e) => Some(e),
|
||||
// _ => None
|
||||
// }
|
||||
// }
|
||||
}
|
|
@ -0,0 +1,212 @@
|
|||
use std::io;
|
||||
|
||||
use figment::value::magic::{RelativePathBuf, Either};
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
/// Mutual TLS configuration.
|
||||
///
|
||||
/// Configuration works in concert with the [`mtls`](crate::mtls) module, which
|
||||
/// provides a request guard to validate, verify, and retrieve client
|
||||
/// certificates in routes.
|
||||
///
|
||||
/// By default, mutual TLS is disabled and client certificates are not required,
|
||||
/// validated or verified. To enable mutual TLS, the `mtls` feature must be
|
||||
/// enabled and support configured via two `tls.mutual` parameters:
|
||||
///
|
||||
/// * `ca_certs`
|
||||
///
|
||||
/// A required path to a PEM file or raw bytes to a DER-encoded X.509 TLS
|
||||
/// certificate chain for the certificate authority to verify client
|
||||
/// certificates against. When a path is configured in a file, such as
|
||||
/// `Rocket.toml`, relative paths are interpreted as relative to the source
|
||||
/// file's directory.
|
||||
///
|
||||
/// * `mandatory`
|
||||
///
|
||||
/// An optional boolean that control whether client authentication is
|
||||
/// required.
|
||||
///
|
||||
/// When `true`, client authentication is required. TLS connections where
|
||||
/// the client does not present a certificate are immediately terminated.
|
||||
/// When `false`, the client is not required to present a certificate. In
|
||||
/// either case, if a certificate _is_ presented, it must be valid or the
|
||||
/// connection is terminated.
|
||||
///
|
||||
/// In a `Rocket.toml`, configuration might look like:
|
||||
///
|
||||
/// ```toml
|
||||
/// [default.tls.mutual]
|
||||
/// ca_certs = "/ssl/ca_cert.pem"
|
||||
/// mandatory = true # when absent, defaults to false
|
||||
/// ```
|
||||
///
|
||||
/// Programmatically, configuration might look like:
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::MtlsConfig;
|
||||
/// use rocket::figment::providers::Serialized;
|
||||
///
|
||||
/// #[launch]
|
||||
/// fn rocket() -> _ {
|
||||
/// let mtls = MtlsConfig::from_path("/ssl/ca_cert.pem");
|
||||
/// rocket::custom(rocket::Config::figment().merge(("tls.mutual", mtls)))
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Once mTLS is configured, the [`mtls::Certificate`](crate::mtls::Certificate)
|
||||
/// request guard can be used to retrieve client certificates in routes.
|
||||
#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MtlsConfig {
|
||||
/// Path to a PEM file with, or raw bytes for, DER-encoded Certificate
|
||||
/// Authority certificates which will be used to verify client-presented
|
||||
/// certificates.
|
||||
// TODO: Support more than one CA root.
|
||||
pub(crate) ca_certs: Either<RelativePathBuf, Vec<u8>>,
|
||||
/// Whether the client is required to present a certificate.
|
||||
///
|
||||
/// When `true`, the client is required to present a valid certificate to
|
||||
/// proceed with TLS. When `false`, the client is not required to present a
|
||||
/// certificate. In either case, if a certificate _is_ presented, it must be
|
||||
/// valid or the connection is terminated.
|
||||
#[serde(default)]
|
||||
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
|
||||
pub mandatory: bool,
|
||||
}
|
||||
|
||||
impl MtlsConfig {
|
||||
/// Constructs a `MtlsConfig` from a path to a PEM file with a certificate
|
||||
/// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This
|
||||
/// method does no validation; it simply creates a structure suitable for
|
||||
/// passing into a [`TlsConfig`].
|
||||
///
|
||||
/// These certificates will be used to verify client-presented certificates
|
||||
/// in TLS connections.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::mtls::MtlsConfig;
|
||||
///
|
||||
/// let tls_config = MtlsConfig::from_path("/ssl/ca_certs.pem");
|
||||
/// ```
|
||||
pub fn from_path<C: AsRef<std::path::Path>>(ca_certs: C) -> Self {
|
||||
MtlsConfig {
|
||||
ca_certs: Either::Left(ca_certs.as_ref().to_path_buf().into()),
|
||||
mandatory: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs a `MtlsConfig` from a byte buffer to a certificate authority
|
||||
/// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no
|
||||
/// validation; it simply creates a structure suitable for passing into a
|
||||
/// [`TlsConfig`].
|
||||
///
|
||||
/// These certificates will be used to verify client-presented certificates
|
||||
/// in TLS connections.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::mtls::MtlsConfig;
|
||||
///
|
||||
/// # let ca_certs_buf = &[];
|
||||
/// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf);
|
||||
/// ```
|
||||
pub fn from_bytes(ca_certs: &[u8]) -> Self {
|
||||
MtlsConfig {
|
||||
ca_certs: Either::Right(ca_certs.to_vec()),
|
||||
mandatory: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets whether client authentication is required. Disabled by default.
|
||||
///
|
||||
/// When `true`, client authentication will be required. TLS connections
|
||||
/// where the client does not present a certificate will be immediately
|
||||
/// terminated. When `false`, the client is not required to present a
|
||||
/// certificate. In either case, if a certificate _is_ presented, it must be
|
||||
/// valid or the connection is terminated.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::mtls::MtlsConfig;
|
||||
///
|
||||
/// # let ca_certs_buf = &[];
|
||||
/// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf).mandatory(true);
|
||||
/// ```
|
||||
pub fn mandatory(mut self, mandatory: bool) -> Self {
|
||||
self.mandatory = mandatory;
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the value of the `ca_certs` parameter.
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::mtls::MtlsConfig;
|
||||
///
|
||||
/// # let ca_certs_buf = &[];
|
||||
/// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf).mandatory(true);
|
||||
/// assert_eq!(mtls_config.ca_certs().unwrap_right(), ca_certs_buf);
|
||||
/// ```
|
||||
pub fn ca_certs(&self) -> either::Either<std::path::PathBuf, &[u8]> {
|
||||
match &self.ca_certs {
|
||||
Either::Left(path) => either::Either::Left(path.relative()),
|
||||
Either::Right(bytes) => either::Either::Right(&bytes),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn ca_certs_reader(&self) -> io::Result<Box<dyn io::BufRead + Sync + Send>> {
|
||||
crate::tls::config::to_reader(&self.ca_certs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::Path;
|
||||
use figment::{Figment, providers::{Toml, Format}};
|
||||
|
||||
use crate::mtls::MtlsConfig;
|
||||
|
||||
#[test]
|
||||
fn test_mtls_config() {
|
||||
figment::Jail::expect_with(|jail| {
|
||||
jail.create_file("MTLS.toml", r#"
|
||||
certs = "/ssl/cert.pem"
|
||||
key = "/ssl/key.pem"
|
||||
"#)?;
|
||||
|
||||
let figment = || Figment::from(Toml::file("MTLS.toml"));
|
||||
figment().extract::<MtlsConfig>().expect_err("no ca");
|
||||
|
||||
jail.create_file("MTLS.toml", r#"
|
||||
ca_certs = "/ssl/ca.pem"
|
||||
"#)?;
|
||||
|
||||
let mtls: MtlsConfig = figment().extract()?;
|
||||
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
|
||||
assert!(!mtls.mandatory);
|
||||
|
||||
jail.create_file("MTLS.toml", r#"
|
||||
ca_certs = "/ssl/ca.pem"
|
||||
mandatory = true
|
||||
"#)?;
|
||||
|
||||
let mtls: MtlsConfig = figment().extract()?;
|
||||
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
|
||||
assert!(mtls.mandatory);
|
||||
|
||||
jail.create_file("MTLS.toml", r#"
|
||||
ca_certs = "relative/ca.pem"
|
||||
"#)?;
|
||||
|
||||
let mtls: MtlsConfig = figment().extract()?;
|
||||
assert_eq!(mtls.ca_certs().unwrap_left(), jail.directory().join("relative/ca.pem"));
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
use std::fmt;
|
||||
use std::num::NonZeroUsize;
|
||||
|
||||
use crate::mtls::x509::{self, nom};
|
||||
|
||||
/// An error returned by the [`Certificate`] request guard.
|
||||
///
|
||||
/// To retrieve this error in a handler, use an `mtls::Result<Certificate>`
|
||||
/// guard type:
|
||||
///
|
||||
/// ```rust
|
||||
/// # extern crate rocket;
|
||||
/// # use rocket::get;
|
||||
/// use rocket::mtls::{self, Certificate};
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: mtls::Result<Certificate<'_>>) {
|
||||
/// match cert {
|
||||
/// Ok(cert) => { /* do something with the client cert */ },
|
||||
/// Err(e) => { /* do something with the error */ },
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub enum Error {
|
||||
/// The certificate chain presented by the client had no certificates.
|
||||
Empty,
|
||||
/// The certificate contained neither a subject nor a subjectAlt extension.
|
||||
NoSubject,
|
||||
/// There is no subject and the subjectAlt is not marked as critical.
|
||||
NonCriticalSubjectAlt,
|
||||
/// An error occurred while parsing the certificate.
|
||||
Parse(x509::X509Error),
|
||||
/// The certificate parsed partially but is incomplete.
|
||||
///
|
||||
/// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number
|
||||
/// of expected bytes is unknown.
|
||||
Incomplete(Option<NonZeroUsize>),
|
||||
/// The certificate contained `.0` bytes of trailing data.
|
||||
Trailing(usize),
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Error::Parse(e) => write!(f, "parse error: {}", e),
|
||||
Error::Incomplete(_) => write!(f, "incomplete certificate data"),
|
||||
Error::Trailing(n) => write!(f, "found {} trailing bytes", n),
|
||||
Error::Empty => write!(f, "empty certificate chain"),
|
||||
Error::NoSubject => write!(f, "empty subject without subjectAlt"),
|
||||
Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<nom::Err<x509::X509Error>> for Error {
|
||||
fn from(e: nom::Err<x509::X509Error>) -> Self {
|
||||
match e {
|
||||
nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None),
|
||||
nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)),
|
||||
nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
// fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
// match self {
|
||||
// Error::Parse(e) => Some(e),
|
||||
// _ => None
|
||||
// }
|
||||
// }
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
//! Support for mutual TLS client certificates.
|
||||
//!
|
||||
//! For details on how to configure mutual TLS, see
|
||||
//! [`MutualTls`](crate::config::MutualTls) and the [TLS
|
||||
//! guide](https://rocket.rs/master/guide/configuration/#tls). See
|
||||
//! [`Certificate`] for a request guard that validated, verifies, and retrieves
|
||||
//! client certificates.
|
||||
|
||||
pub mod oid {
|
||||
//! Lower-level OID types re-exported from
|
||||
//! [`oid_registry`](https://docs.rs/oid-registry/0.4) and
|
||||
//! [`der-parser`](https://docs.rs/der-parser/7).
|
||||
|
||||
pub use x509_parser::oid_registry::*;
|
||||
pub use x509_parser::objects::*;
|
||||
}
|
||||
|
||||
pub mod bigint {
|
||||
//! Signed and unsigned big integer types re-exported from
|
||||
//! [`num_bigint`](https://docs.rs/num-bigint/0.4).
|
||||
pub use x509_parser::der_parser::num_bigint::*;
|
||||
}
|
||||
|
||||
pub mod x509 {
|
||||
//! Lower-level X.509 types re-exported from
|
||||
//! [`x509_parser`](https://docs.rs/x509-parser/0.13).
|
||||
//!
|
||||
//! Lack of documentation is directly inherited from the source crate.
|
||||
//! Prefer to use Rocket's wrappers when possible.
|
||||
|
||||
pub(crate) use x509_parser::nom;
|
||||
pub use x509_parser::certificate::*;
|
||||
pub use x509_parser::cri_attributes::*;
|
||||
pub use x509_parser::error::*;
|
||||
pub use x509_parser::extensions::*;
|
||||
pub use x509_parser::revocation_list::*;
|
||||
pub use x509_parser::time::*;
|
||||
pub use x509_parser::x509::*;
|
||||
pub use x509_parser::der_parser::der;
|
||||
pub use x509_parser::der_parser::ber;
|
||||
pub use x509_parser::traits::*;
|
||||
}
|
||||
|
||||
mod certificate;
|
||||
mod error;
|
||||
mod name;
|
||||
mod config;
|
||||
|
||||
pub use error::Error;
|
||||
pub use name::Name;
|
||||
pub use config::MtlsConfig;
|
||||
pub use certificate::{Certificate, CertificateDer};
|
||||
|
||||
/// A type alias for [`Result`](std::result::Result) with the error type set to
|
||||
/// [`Error`].
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
|
@ -0,0 +1,146 @@
|
|||
use std::fmt;
|
||||
use std::ops::Deref;
|
||||
|
||||
use ref_cast::RefCast;
|
||||
|
||||
use crate::mtls::x509::X509Name;
|
||||
use crate::mtls::oid;
|
||||
|
||||
/// An X.509 Distinguished Name (DN) found in a [`Certificate`].
|
||||
///
|
||||
/// This type is a wrapper over [`x509::X509Name`] with convenient methods and
|
||||
/// complete documentation. Should the data exposed by the inherent methods not
|
||||
/// suffice, this type derefs to [`x509::X509Name`].
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, PartialEq, RefCast)]
|
||||
pub struct Name<'a>(X509Name<'a>);
|
||||
|
||||
impl<'a> Name<'a> {
|
||||
/// Returns the _first_ UTF-8 _string_ common name, if any.
|
||||
///
|
||||
/// Note that common names need not be UTF-8 strings, or strings at all.
|
||||
/// This method returns the first common name attribute that is.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::Certificate;
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: Certificate<'_>) {
|
||||
/// if let Some(name) = cert.subject().common_name() {
|
||||
/// println!("Hello, {}!", name);
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn common_name(&self) -> Option<&'a str> {
|
||||
self.common_names().next()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all of the UTF-8 _string_ common names in
|
||||
/// `self`.
|
||||
///
|
||||
/// Note that common names need not be UTF-8 strings, or strings at all.
|
||||
/// This method filters the common names in `self` to those that are. Use
|
||||
/// the raw [`iter_common_name()`](#method.iter_common_name) to iterate over
|
||||
/// all value types.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::Certificate;
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: Certificate<'_>) {
|
||||
/// for name in cert.issuer().common_names() {
|
||||
/// println!("Issued by {}.", name);
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn common_names(&self) -> impl Iterator<Item = &'a str> + '_ {
|
||||
self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok())
|
||||
}
|
||||
|
||||
/// Returns the _first_ UTF-8 _string_ email address, if any.
|
||||
///
|
||||
/// Note that email addresses need not be UTF-8 strings, or strings at all.
|
||||
/// This method returns the first email address attribute that is.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::Certificate;
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: Certificate<'_>) {
|
||||
/// if let Some(email) = cert.subject().email() {
|
||||
/// println!("Hello, {}!", email);
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn email(&self) -> Option<&'a str> {
|
||||
self.emails().next()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all of the UTF-8 _string_ email addresses in
|
||||
/// `self`.
|
||||
///
|
||||
/// Note that email addresses need not be UTF-8 strings, or strings at all.
|
||||
/// This method filters the email address in `self` to those that are. Use
|
||||
/// the raw [`iter_email()`](#method.iter_email) to iterate over all value
|
||||
/// types.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::Certificate;
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: Certificate<'_>) {
|
||||
/// for email in cert.subject().emails() {
|
||||
/// println!("Reach me at: {}", email);
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn emails(&self) -> impl Iterator<Item = &'a str> + '_ {
|
||||
self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok())
|
||||
}
|
||||
|
||||
/// Returns `true` if `self` has no data.
|
||||
///
|
||||
/// When this is the case for a `subject()`, the subject data can be found
|
||||
/// in the `subjectAlt` [`extension()`](Certificate::extensions()).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use rocket::mtls::Certificate;
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: Certificate<'_>) {
|
||||
/// let no_data = cert.subject().is_empty();
|
||||
/// }
|
||||
/// ```
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.as_raw().is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Deref for Name<'a> {
|
||||
type Target = X509Name<'a>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Name<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
use state::TypeMap;
|
||||
use figment::Figment;
|
||||
|
||||
use crate::listener::Endpoint;
|
||||
use crate::{Catcher, Config, Rocket, Route, Shutdown};
|
||||
use crate::router::Router;
|
||||
use crate::fairing::Fairings;
|
||||
|
@ -113,5 +114,6 @@ phases! {
|
|||
pub(crate) config: Config,
|
||||
pub(crate) state: TypeMap![Send + Sync],
|
||||
pub(crate) shutdown: Shutdown,
|
||||
pub(crate) endpoint: Endpoint,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
use crate::http::Method;
|
||||
|
||||
pub struct AtomicMethod(ref_swap::RefSwap<'static, Method>);
|
||||
|
||||
#[inline(always)]
|
||||
const fn makeref(method: Method) -> &'static Method {
|
||||
match method {
|
||||
Method::Get => &Method::Get,
|
||||
Method::Put => &Method::Put,
|
||||
Method::Post => &Method::Post,
|
||||
Method::Delete => &Method::Delete,
|
||||
Method::Options => &Method::Options,
|
||||
Method::Head => &Method::Head,
|
||||
Method::Trace => &Method::Trace,
|
||||
Method::Connect => &Method::Connect,
|
||||
Method::Patch => &Method::Patch,
|
||||
}
|
||||
}
|
||||
|
||||
impl AtomicMethod {
|
||||
pub fn new(value: Method) -> Self {
|
||||
Self(ref_swap::RefSwap::new(makeref(value)))
|
||||
}
|
||||
|
||||
pub fn load(&self) -> Method {
|
||||
*self.0.load(std::sync::atomic::Ordering::Acquire)
|
||||
}
|
||||
|
||||
pub fn set(&mut self, new: Method) {
|
||||
*self = Self::new(new);
|
||||
}
|
||||
|
||||
pub fn store(&self, new: Method) {
|
||||
self.0.store(makeref(new), std::sync::atomic::Ordering::Release)
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for AtomicMethod {
|
||||
fn clone(&self) -> Self {
|
||||
let inner = self.0.load(std::sync::atomic::Ordering::Acquire);
|
||||
Self(ref_swap::RefSwap::new(inner))
|
||||
}
|
||||
}
|
|
@ -1,12 +1,13 @@
|
|||
use std::convert::Infallible;
|
||||
use std::fmt::Debug;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::net::IpAddr;
|
||||
|
||||
use crate::{Request, Route};
|
||||
use crate::outcome::{self, IntoOutcome, Outcome::*};
|
||||
|
||||
use crate::http::uri::{Host, Origin};
|
||||
use crate::http::{Status, ContentType, Accept, Method, ProxyProto, CookieJar};
|
||||
use crate::listener::Endpoint;
|
||||
|
||||
/// Type alias for the `Outcome` of a `FromRequest` conversion.
|
||||
pub type Outcome<S, E> = outcome::Outcome<S, (Status, E), Status>;
|
||||
|
@ -486,14 +487,22 @@ impl<'r> FromRequest<'r> for ProxyProto<'r> {
|
|||
}
|
||||
|
||||
#[crate::async_trait]
|
||||
impl<'r> FromRequest<'r> for SocketAddr {
|
||||
impl<'r> FromRequest<'r> for &'r Endpoint {
|
||||
type Error = Infallible;
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
|
||||
match request.remote() {
|
||||
Some(addr) => Success(addr),
|
||||
None => Forward(Status::InternalServerError)
|
||||
}
|
||||
request.remote().or_forward(Status::InternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
#[crate::async_trait]
|
||||
impl<'r> FromRequest<'r> for std::net::SocketAddr {
|
||||
type Error = Infallible;
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
|
||||
request.remote()
|
||||
.and_then(|r| r.tcp())
|
||||
.or_forward(Status::InternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
mod request;
|
||||
mod from_param;
|
||||
mod from_request;
|
||||
mod atomic_method;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
@ -15,6 +16,7 @@ pub use self::from_param::{FromParam, FromSegments};
|
|||
pub use crate::response::flash::FlashMessage;
|
||||
|
||||
pub(crate) use self::request::ConnectionMeta;
|
||||
pub(crate) use self::atomic_method::AtomicMethod;
|
||||
|
||||
crate::export! {
|
||||
/// Store and immediately retrieve a vector-like value `$v` (`String` or
|
||||
|
|
|
@ -1,22 +1,24 @@
|
|||
use std::fmt;
|
||||
use std::ops::RangeFrom;
|
||||
use std::{future::Future, borrow::Cow, sync::Arc};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::{Arc, atomic::Ordering};
|
||||
use std::borrow::Cow;
|
||||
use std::future::Future;
|
||||
use std::net::IpAddr;
|
||||
|
||||
use yansi::Paint;
|
||||
use state::{TypeMap, InitCell};
|
||||
use futures::future::BoxFuture;
|
||||
use atomic::{Atomic, Ordering};
|
||||
use ref_swap::OptionRefSwap;
|
||||
|
||||
use crate::{Rocket, Route, Orbit};
|
||||
use crate::request::{FromParam, FromSegments, FromRequest, Outcome};
|
||||
use crate::request::{FromParam, FromSegments, FromRequest, Outcome, AtomicMethod};
|
||||
use crate::form::{self, ValueField, FromForm};
|
||||
use crate::data::Limits;
|
||||
|
||||
use crate::http::{hyper, Method, Header, HeaderMap, ProxyProto};
|
||||
use crate::http::{ContentType, Accept, MediaType, CookieJar, Cookie};
|
||||
use crate::http::private::Certificates;
|
||||
use crate::http::ProxyProto;
|
||||
use crate::http::{Method, Header, HeaderMap, ContentType, Accept, MediaType, CookieJar, Cookie};
|
||||
use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority};
|
||||
use crate::listener::{Certificates, Endpoint, Connection};
|
||||
|
||||
/// The type of an incoming web request.
|
||||
///
|
||||
|
@ -24,26 +26,37 @@ use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority};
|
|||
/// should likely only be used when writing [`FromRequest`] implementations. It
|
||||
/// contains all of the information for a given web request except for the body
|
||||
/// data. This includes the HTTP method, URI, cookies, headers, and more.
|
||||
#[derive(Clone)]
|
||||
pub struct Request<'r> {
|
||||
method: Atomic<Method>,
|
||||
method: AtomicMethod,
|
||||
uri: Origin<'r>,
|
||||
headers: HeaderMap<'r>,
|
||||
pub(crate) errors: Vec<RequestError>,
|
||||
pub(crate) connection: ConnectionMeta,
|
||||
pub(crate) state: RequestState<'r>,
|
||||
}
|
||||
|
||||
/// Information derived from an incoming connection, if any.
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Default)]
|
||||
pub(crate) struct ConnectionMeta {
|
||||
pub remote: Option<SocketAddr>,
|
||||
pub peer_address: Option<Arc<Endpoint>>,
|
||||
#[cfg_attr(not(feature = "mtls"), allow(dead_code))]
|
||||
pub client_certificates: Option<Certificates>,
|
||||
pub peer_certs: Option<Arc<Certificates<'static>>>,
|
||||
}
|
||||
|
||||
impl<C: Connection> From<&C> for ConnectionMeta {
|
||||
fn from(conn: &C) -> Self {
|
||||
ConnectionMeta {
|
||||
peer_address: conn.peer_address().ok().map(Arc::new),
|
||||
peer_certs: conn.peer_certificates().map(|c| c.into_owned()).map(Arc::new),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Information derived from the request.
|
||||
pub(crate) struct RequestState<'r> {
|
||||
pub rocket: &'r Rocket<Orbit>,
|
||||
pub route: Atomic<Option<&'r Route>>,
|
||||
pub route: OptionRefSwap<'r, Route>,
|
||||
pub cookies: CookieJar<'r>,
|
||||
pub accept: InitCell<Option<Accept>>,
|
||||
pub content_type: InitCell<Option<ContentType>>,
|
||||
|
@ -51,23 +64,11 @@ pub(crate) struct RequestState<'r> {
|
|||
pub host: Option<Host<'r>>,
|
||||
}
|
||||
|
||||
impl Request<'_> {
|
||||
pub(crate) fn clone(&self) -> Self {
|
||||
Request {
|
||||
method: Atomic::new(self.method()),
|
||||
uri: self.uri.clone(),
|
||||
headers: self.headers.clone(),
|
||||
connection: self.connection.clone(),
|
||||
state: self.state.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RequestState<'_> {
|
||||
impl Clone for RequestState<'_> {
|
||||
fn clone(&self) -> Self {
|
||||
RequestState {
|
||||
rocket: self.rocket,
|
||||
route: Atomic::new(self.route.load(Ordering::Acquire)),
|
||||
route: OptionRefSwap::new(self.route.load(Ordering::Acquire)),
|
||||
cookies: self.cookies.clone(),
|
||||
accept: self.accept.clone(),
|
||||
content_type: self.content_type.clone(),
|
||||
|
@ -87,15 +88,13 @@ impl<'r> Request<'r> {
|
|||
) -> Request<'r> {
|
||||
Request {
|
||||
uri,
|
||||
method: Atomic::new(method),
|
||||
method: AtomicMethod::new(method),
|
||||
headers: HeaderMap::new(),
|
||||
connection: ConnectionMeta {
|
||||
remote: None,
|
||||
client_certificates: None,
|
||||
},
|
||||
errors: Vec::new(),
|
||||
connection: ConnectionMeta::default(),
|
||||
state: RequestState {
|
||||
rocket,
|
||||
route: Atomic::new(None),
|
||||
route: OptionRefSwap::new(None),
|
||||
cookies: CookieJar::new(None, rocket),
|
||||
accept: InitCell::new(),
|
||||
content_type: InitCell::new(),
|
||||
|
@ -120,7 +119,7 @@ impl<'r> Request<'r> {
|
|||
/// ```
|
||||
#[inline(always)]
|
||||
pub fn method(&self) -> Method {
|
||||
self.method.load(Ordering::Acquire)
|
||||
self.method.load()
|
||||
}
|
||||
|
||||
/// Set the method of `self` to `method`.
|
||||
|
@ -140,7 +139,7 @@ impl<'r> Request<'r> {
|
|||
/// ```
|
||||
#[inline(always)]
|
||||
pub fn set_method(&mut self, method: Method) {
|
||||
self._set_method(method);
|
||||
self.method.set(method);
|
||||
}
|
||||
|
||||
/// Borrow the [`Origin`] URI from `self`.
|
||||
|
@ -324,20 +323,20 @@ impl<'r> Request<'r> {
|
|||
///
|
||||
/// assert_eq!(request.remote(), None);
|
||||
///
|
||||
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000).into();
|
||||
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000);
|
||||
/// request.set_remote(localhost);
|
||||
/// assert_eq!(request.remote(), Some(localhost));
|
||||
/// assert_eq!(request.remote().unwrap(), &localhost);
|
||||
/// ```
|
||||
#[inline(always)]
|
||||
pub fn remote(&self) -> Option<SocketAddr> {
|
||||
self.connection.remote
|
||||
pub fn remote(&self) -> Option<&Endpoint> {
|
||||
self.connection.peer_address.as_deref()
|
||||
}
|
||||
|
||||
/// Sets the remote address of `self` to `address`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// Set the remote address to be 127.0.0.1:8000:
|
||||
/// Set the remote address to be 127.0.0.1:8111:
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::net::{SocketAddrV4, Ipv4Addr};
|
||||
|
@ -347,13 +346,13 @@ impl<'r> Request<'r> {
|
|||
///
|
||||
/// assert_eq!(request.remote(), None);
|
||||
///
|
||||
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000).into();
|
||||
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8111);
|
||||
/// request.set_remote(localhost);
|
||||
/// assert_eq!(request.remote(), Some(localhost));
|
||||
/// assert_eq!(request.remote().unwrap(), &localhost);
|
||||
/// ```
|
||||
#[inline(always)]
|
||||
pub fn set_remote(&mut self, address: SocketAddr) {
|
||||
self.connection.remote = Some(address);
|
||||
pub fn set_remote<A: Into<Endpoint>>(&mut self, address: A) {
|
||||
self.connection.peer_address = Some(Arc::new(address.into()));
|
||||
}
|
||||
|
||||
/// Returns the IP address of the configured
|
||||
|
@ -489,25 +488,26 @@ impl<'r> Request<'r> {
|
|||
///
|
||||
/// ```rust
|
||||
/// # use rocket::http::Header;
|
||||
/// # use std::net::{SocketAddr, IpAddr, Ipv4Addr};
|
||||
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
|
||||
/// # let mut req = c.get("/");
|
||||
/// # let request = req.inner_mut();
|
||||
/// # use std::net::{SocketAddrV4, Ipv4Addr};
|
||||
///
|
||||
/// // starting without an "X-Real-IP" header or remote address
|
||||
/// assert!(request.client_ip().is_none());
|
||||
///
|
||||
/// // add a remote address; this is done by Rocket automatically
|
||||
/// request.set_remote("127.0.0.1:8000".parse().unwrap());
|
||||
/// assert_eq!(request.client_ip(), Some("127.0.0.1".parse().unwrap()));
|
||||
/// let localhost_9190 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9190);
|
||||
/// request.set_remote(localhost_9190);
|
||||
/// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::LOCALHOST);
|
||||
///
|
||||
/// // now with an X-Real-IP header, the default value for `ip_header`.
|
||||
/// request.add_header(Header::new("X-Real-IP", "8.8.8.8"));
|
||||
/// assert_eq!(request.client_ip(), Some("8.8.8.8".parse().unwrap()));
|
||||
/// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::new(8, 8, 8, 8));
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn client_ip(&self) -> Option<IpAddr> {
|
||||
self.real_ip().or_else(|| self.remote().map(|r| r.ip()))
|
||||
self.real_ip().or_else(|| Some(self.remote()?.tcp()?.ip()))
|
||||
}
|
||||
|
||||
/// Returns a wrapped borrow to the cookies in `self`.
|
||||
|
@ -691,7 +691,7 @@ impl<'r> Request<'r> {
|
|||
if self.method().supports_payload() {
|
||||
self.content_type().map(|ct| ct.media_type())
|
||||
} else {
|
||||
// FIXME: Should we be using `accept_first` or `preferred`? Or
|
||||
// TODO: Should we be using `accept_first` or `preferred`? Or
|
||||
// should we be checking neither and instead pass things through
|
||||
// where the client accepts the thing at all?
|
||||
self.accept()
|
||||
|
@ -1056,11 +1056,9 @@ impl<'r> Request<'r> {
|
|||
self.state.route.store(Some(route), Ordering::Release)
|
||||
}
|
||||
|
||||
/// Set the method of `self`, even when `self` is a shared reference. Used
|
||||
/// during routing to override methods for re-routing.
|
||||
#[inline(always)]
|
||||
pub(crate) fn _set_method(&self, method: Method) {
|
||||
self.method.store(method, Ordering::Release)
|
||||
self.method.store(method)
|
||||
}
|
||||
|
||||
pub(crate) fn cookies_mut(&mut self) -> &mut CookieJar<'r> {
|
||||
|
@ -1070,18 +1068,28 @@ impl<'r> Request<'r> {
|
|||
/// Convert from Hyper types into a Rocket Request.
|
||||
pub(crate) fn from_hyp(
|
||||
rocket: &'r Rocket<Orbit>,
|
||||
hyper: &'r hyper::request::Parts,
|
||||
connection: Option<ConnectionMeta>,
|
||||
) -> Result<Request<'r>, BadRequest<'r>> {
|
||||
hyper: &'r hyper::http::request::Parts,
|
||||
connection: ConnectionMeta,
|
||||
) -> Result<Request<'r>, Request<'r>> {
|
||||
// Keep track of parsing errors; emit a `BadRequest` if any exist.
|
||||
let mut errors = vec![];
|
||||
|
||||
// Ensure that the method is known. TODO: Allow made-up methods?
|
||||
let method = Method::from_hyp(&hyper.method)
|
||||
.unwrap_or_else(|| {
|
||||
errors.push(Kind::BadMethod(&hyper.method));
|
||||
let method = match hyper.method {
|
||||
hyper::Method::GET => Method::Get,
|
||||
hyper::Method::PUT => Method::Put,
|
||||
hyper::Method::POST => Method::Post,
|
||||
hyper::Method::DELETE => Method::Delete,
|
||||
hyper::Method::OPTIONS => Method::Options,
|
||||
hyper::Method::HEAD => Method::Head,
|
||||
hyper::Method::TRACE => Method::Trace,
|
||||
hyper::Method::CONNECT => Method::Connect,
|
||||
hyper::Method::PATCH => Method::Patch,
|
||||
_ => {
|
||||
errors.push(RequestError::BadMethod(hyper.method.clone()));
|
||||
Method::Get
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Keep around not just the path/query, but the rest, if there?
|
||||
let uri = hyper.uri.path_and_query()
|
||||
|
@ -1100,20 +1108,20 @@ impl<'r> Request<'r> {
|
|||
Origin::new(uri.path(), uri.query().map(Cow::Borrowed))
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
errors.push(Kind::InvalidUri(&hyper.uri));
|
||||
errors.push(RequestError::InvalidUri(hyper.uri.clone()));
|
||||
Origin::ROOT
|
||||
});
|
||||
|
||||
// Construct the request object; fill in metadata and headers next.
|
||||
let mut request = Request::new(rocket, method, uri);
|
||||
request.errors = errors;
|
||||
|
||||
// Set the passed in connection metadata.
|
||||
if let Some(connection) = connection {
|
||||
request.connection = connection;
|
||||
}
|
||||
request.connection = connection;
|
||||
|
||||
// Determine + set host. On HTTP < 2, use the `HOST` header. Otherwise,
|
||||
// use the `:authority` pseudo-header which hyper makes part of the URI.
|
||||
// TODO: Use an `InitCell` to compute this later.
|
||||
request.state.host = if hyper.version < hyper::Version::HTTP_2 {
|
||||
hyper.headers.get("host").and_then(|h| Host::parse_bytes(h.as_bytes()).ok())
|
||||
} else {
|
||||
|
@ -1122,9 +1130,8 @@ impl<'r> Request<'r> {
|
|||
|
||||
// Set the request cookies, if they exist.
|
||||
for header in hyper.headers.get_all("Cookie") {
|
||||
let raw_str = match std::str::from_utf8(header.as_bytes()) {
|
||||
Ok(string) => string,
|
||||
Err(_) => continue
|
||||
let Ok(raw_str) = std::str::from_utf8(header.as_bytes()) else {
|
||||
continue
|
||||
};
|
||||
|
||||
for cookie_str in raw_str.split(';').map(|s| s.trim()) {
|
||||
|
@ -1137,43 +1144,33 @@ impl<'r> Request<'r> {
|
|||
// Set the rest of the headers. This is rather unfortunate and slow.
|
||||
for (name, value) in hyper.headers.iter() {
|
||||
// FIXME: This is rather unfortunate. Header values needn't be UTF8.
|
||||
let value = match std::str::from_utf8(value.as_bytes()) {
|
||||
Ok(value) => value,
|
||||
Err(_) => {
|
||||
warn!("Header '{}' contains invalid UTF-8", name);
|
||||
warn_!("Rocket only supports UTF-8 header values. Dropping header.");
|
||||
continue;
|
||||
}
|
||||
let Ok(value) = std::str::from_utf8(value.as_bytes()) else {
|
||||
warn!("Header '{}' contains invalid UTF-8", name);
|
||||
warn_!("Rocket only supports UTF-8 header values. Dropping header.");
|
||||
continue;
|
||||
};
|
||||
|
||||
request.add_header(Header::new(name.as_str(), value));
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(request)
|
||||
} else {
|
||||
Err(BadRequest { request, errors })
|
||||
match request.errors.is_empty() {
|
||||
true => Ok(request),
|
||||
false => Err(request),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct BadRequest<'r> {
|
||||
pub request: Request<'r>,
|
||||
pub errors: Vec<Kind<'r>>,
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum RequestError {
|
||||
InvalidUri(hyper::Uri),
|
||||
BadMethod(hyper::Method),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum Kind<'r> {
|
||||
InvalidUri(&'r hyper::Uri),
|
||||
BadMethod(&'r hyper::Method),
|
||||
}
|
||||
|
||||
impl fmt::Display for Kind<'_> {
|
||||
impl fmt::Display for RequestError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Kind::InvalidUri(u) => write!(f, "invalid origin URI: {}", u),
|
||||
Kind::BadMethod(m) => write!(f, "invalid or unrecognized method: {}", m),
|
||||
RequestError::InvalidUri(u) => write!(f, "invalid origin URI: {}", u),
|
||||
RequestError::BadMethod(m) => write!(f, "invalid or unrecognized method: {}", m),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1181,8 +1178,8 @@ impl fmt::Display for Kind<'_> {
|
|||
impl fmt::Debug for Request<'_> {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt.debug_struct("Request")
|
||||
.field("method", &self.method)
|
||||
.field("uri", &self.uri)
|
||||
.field("method", &self.method())
|
||||
.field("uri", &self.uri())
|
||||
.field("headers", &self.headers())
|
||||
.field("remote", &self.remote())
|
||||
.field("cookies", &self.cookies())
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::Request;
|
||||
use crate::request::{Request, ConnectionMeta};
|
||||
use crate::local::blocking::Client;
|
||||
use crate::http::hyper;
|
||||
|
||||
macro_rules! assert_headers {
|
||||
($($key:expr => [$($value:expr),+]),+) => ({
|
||||
// Create a new Hyper request. Add all of the passed in headers.
|
||||
let mut req = hyper::Request::get("/test").body(()).unwrap();
|
||||
$($(req.headers_mut().append($key, hyper::HeaderValue::from_str($value).unwrap());)+)+
|
||||
$($(
|
||||
req.headers_mut()
|
||||
.append($key, hyper::header::HeaderValue::from_str($value).unwrap());
|
||||
)+)+
|
||||
|
||||
// Build up what we expect the headers to actually be.
|
||||
let mut expected = HashMap::new();
|
||||
|
@ -17,7 +19,8 @@ macro_rules! assert_headers {
|
|||
// Create a valid `Rocket` and convert the hyper req to a Rocket one.
|
||||
let client = Client::debug_with(vec![]).unwrap();
|
||||
let hyper = req.into_parts().0;
|
||||
let req = Request::from_hyp(client.rocket(), &hyper, None).unwrap();
|
||||
let meta = ConnectionMeta::default();
|
||||
let req = Request::from_hyp(client.rocket(), &hyper, meta).unwrap();
|
||||
|
||||
// Dispatch the request and check that the headers match.
|
||||
let actual_headers = req.headers();
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use std::{fmt, str};
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncSeek};
|
||||
|
||||
|
@ -146,19 +145,18 @@ impl<'r> Builder<'r> {
|
|||
/// potentially different values to be present in the `Response`.
|
||||
///
|
||||
/// The type of `header` can be any type that implements `Into<Header>`.
|
||||
/// This includes `Header` itself, [`ContentType`](crate::http::ContentType) and
|
||||
/// [hyper::header types](crate::http::hyper::header).
|
||||
/// This includes `Header` itself, [`ContentType`](crate::http::ContentType)
|
||||
/// and [`Accept`](crate::http::Accept).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::Response;
|
||||
/// use rocket::http::Header;
|
||||
/// use rocket::http::hyper::header::ACCEPT;
|
||||
/// use rocket::http::{Header, Accept};
|
||||
///
|
||||
/// let response = Response::build()
|
||||
/// .header_adjoin(Header::new(ACCEPT.as_str(), "application/json"))
|
||||
/// .header_adjoin(Header::new(ACCEPT.as_str(), "text/plain"))
|
||||
/// .header_adjoin(Header::new("Accept", "application/json"))
|
||||
/// .header_adjoin(Accept::XML)
|
||||
/// .finalize();
|
||||
///
|
||||
/// assert_eq!(response.headers().get("Accept").count(), 2);
|
||||
|
@ -287,7 +285,7 @@ impl<'r> Builder<'r> {
|
|||
///
|
||||
/// #[rocket::async_trait]
|
||||
/// impl IoHandler for EchoHandler {
|
||||
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
||||
/// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||
/// let (mut reader, mut writer) = io::split(io);
|
||||
/// io::copy(&mut reader, &mut writer).await?;
|
||||
/// Ok(())
|
||||
|
@ -488,7 +486,7 @@ pub struct Response<'r> {
|
|||
status: Option<Status>,
|
||||
headers: HeaderMap<'r>,
|
||||
body: Body<'r>,
|
||||
upgrade: HashMap<Uncased<'r>, Pin<Box<dyn IoHandler + 'r>>>,
|
||||
upgrade: HashMap<Uncased<'r>, Box<dyn IoHandler + 'r>>,
|
||||
}
|
||||
|
||||
impl<'r> Response<'r> {
|
||||
|
@ -700,23 +698,22 @@ impl<'r> Response<'r> {
|
|||
/// name `header.name`, another header with the same name and value
|
||||
/// `header.value` is added. The type of `header` can be any type that
|
||||
/// implements `Into<Header>`. This includes `Header` itself,
|
||||
/// [`ContentType`](crate::http::ContentType) and [`hyper::header`
|
||||
/// types](crate::http::hyper::header).
|
||||
/// [`ContentType`](crate::http::ContentType),
|
||||
/// [`Accept`](crate::http::Accept).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::Response;
|
||||
/// use rocket::http::Header;
|
||||
/// use rocket::http::hyper::header::ACCEPT;
|
||||
/// use rocket::http::{Header, Accept};
|
||||
///
|
||||
/// let mut response = Response::new();
|
||||
/// response.adjoin_header(Header::new(ACCEPT.as_str(), "application/json"));
|
||||
/// response.adjoin_header(Header::new(ACCEPT.as_str(), "text/plain"));
|
||||
/// response.adjoin_header(Accept::JSON);
|
||||
/// response.adjoin_header(Header::new("Accept", "text/plain"));
|
||||
///
|
||||
/// let mut accept_headers = response.headers().iter();
|
||||
/// assert_eq!(accept_headers.next(), Some(Header::new(ACCEPT.as_str(), "application/json")));
|
||||
/// assert_eq!(accept_headers.next(), Some(Header::new(ACCEPT.as_str(), "text/plain")));
|
||||
/// assert_eq!(accept_headers.next(), Some(Header::new("Accept", "application/json")));
|
||||
/// assert_eq!(accept_headers.next(), Some(Header::new("Accept", "text/plain")));
|
||||
/// assert_eq!(accept_headers.next(), None);
|
||||
/// ```
|
||||
#[inline(always)]
|
||||
|
@ -801,10 +798,10 @@ impl<'r> Response<'r> {
|
|||
/// the comma-separated protocols any of the strings in `I`. Returns
|
||||
/// `Ok(None)` if `self` doesn't support any kind of upgrade. Returns
|
||||
/// `Err(_)` if `protocols` is non-empty but no match was found in `self`.
|
||||
pub(crate) fn take_upgrade<I: Iterator<Item = &'r str>>(
|
||||
pub(crate) fn search_upgrades<'a, I: Iterator<Item = &'a str>>(
|
||||
&mut self,
|
||||
protocols: I
|
||||
) -> Result<Option<(Uncased<'r>, Pin<Box<dyn IoHandler + 'r>>)>, ()> {
|
||||
) -> Result<Option<(Uncased<'r>, Box<dyn IoHandler + 'r>)>, ()> {
|
||||
if self.upgrade.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
@ -839,7 +836,7 @@ impl<'r> Response<'r> {
|
|||
///
|
||||
/// #[rocket::async_trait]
|
||||
/// impl IoHandler for EchoHandler {
|
||||
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
||||
/// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||
/// let (mut reader, mut writer) = io::split(io);
|
||||
/// io::copy(&mut reader, &mut writer).await?;
|
||||
/// Ok(())
|
||||
|
@ -854,7 +851,7 @@ impl<'r> Response<'r> {
|
|||
/// assert!(response.upgrade("raw-echo").is_some());
|
||||
/// # })
|
||||
/// ```
|
||||
pub fn upgrade(&mut self, proto: &str) -> Option<Pin<&mut (dyn IoHandler + 'r)>> {
|
||||
pub fn upgrade(&mut self, proto: &str) -> Option<&mut (dyn IoHandler + 'r)> {
|
||||
self.upgrade.get_mut(proto.as_uncased()).map(|h| h.as_mut())
|
||||
}
|
||||
|
||||
|
@ -972,7 +969,7 @@ impl<'r> Response<'r> {
|
|||
///
|
||||
/// #[rocket::async_trait]
|
||||
/// impl IoHandler for EchoHandler {
|
||||
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
||||
/// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||
/// let (mut reader, mut writer) = io::split(io);
|
||||
/// io::copy(&mut reader, &mut writer).await?;
|
||||
/// Ok(())
|
||||
|
@ -990,7 +987,7 @@ impl<'r> Response<'r> {
|
|||
pub fn add_upgrade<N, H>(&mut self, protocol: N, handler: H)
|
||||
where N: Into<Uncased<'r>>, H: IoHandler + 'r
|
||||
{
|
||||
self.upgrade.insert(protocol.into(), Box::pin(handler));
|
||||
self.upgrade.insert(protocol.into(), Box::new(handler));
|
||||
}
|
||||
|
||||
/// Sets the body's maximum chunk size to `size` bytes.
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use std::borrow::Cow;
|
||||
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::time::Duration;
|
||||
use futures::stream::{self, Stream, StreamExt};
|
||||
use futures::future::ready;
|
||||
use tokio::time::{interval, Duration};
|
||||
use futures::{stream::{self, Stream}, future::Either};
|
||||
use tokio_stream::{StreamExt, wrappers::IntervalStream};
|
||||
|
||||
use crate::request::Request;
|
||||
use crate::response::{self, Response, Responder, stream::{ReaderStream, RawLinedEvent}};
|
||||
|
@ -336,7 +336,7 @@ impl Event {
|
|||
Some(RawLinedEvent::raw("")),
|
||||
];
|
||||
|
||||
stream::iter(events).filter_map(ready)
|
||||
stream::iter(events).filter_map(|x| x)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -528,25 +528,19 @@ impl<S: Stream<Item = Event>> EventStream<S> {
|
|||
self
|
||||
}
|
||||
|
||||
fn heartbeat_stream(&self) -> Option<impl Stream<Item = RawLinedEvent>> {
|
||||
use tokio::time::interval;
|
||||
use tokio_stream::wrappers::IntervalStream;
|
||||
|
||||
fn heartbeat_stream(&self) -> impl Stream<Item = RawLinedEvent> {
|
||||
self.heartbeat
|
||||
.map(|beat| IntervalStream::new(interval(beat)))
|
||||
.map(|stream| stream.map(|_| RawLinedEvent::raw(":")))
|
||||
.map_or_else(|| Either::Right(stream::empty()), Either::Left)
|
||||
}
|
||||
|
||||
fn into_stream(self) -> impl Stream<Item = RawLinedEvent> {
|
||||
use futures::future::Either;
|
||||
use crate::ext::StreamExt;
|
||||
use futures::StreamExt;
|
||||
|
||||
let heartbeat_stream = self.heartbeat_stream();
|
||||
let raw_events = self.stream.map(|e| e.into_stream()).flatten();
|
||||
match heartbeat_stream {
|
||||
Some(heartbeat) => Either::Left(raw_events.join(heartbeat)),
|
||||
None => Either::Right(raw_events)
|
||||
}
|
||||
let heartbeats = self.heartbeat_stream();
|
||||
let events = StreamExt::map(self.stream, |e| e.into_stream()).flatten();
|
||||
crate::util::join(events, heartbeats)
|
||||
}
|
||||
|
||||
fn into_reader(self) -> impl AsyncRead {
|
||||
|
@ -621,10 +615,11 @@ mod sse_tests {
|
|||
|
||||
impl<S: Stream<Item = Event>> EventStream<S> {
|
||||
fn into_string(self) -> String {
|
||||
use std::pin::pin;
|
||||
|
||||
crate::async_test(async move {
|
||||
let mut string = String::new();
|
||||
let reader = self.into_reader();
|
||||
tokio::pin!(reader);
|
||||
let mut reader = pin!(self.into_reader());
|
||||
reader.read_to_string(&mut string).await.expect("event stream -> string");
|
||||
string
|
||||
})
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
use std::fmt;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use yansi::Paint;
|
||||
use either::Either;
|
||||
use figment::{Figment, Provider};
|
||||
|
||||
use crate::{Catcher, Config, Route, Shutdown, sentinel, shield::Shield};
|
||||
use crate::listener::{Endpoint, Bindable, DefaultListener};
|
||||
use crate::router::Router;
|
||||
use crate::trip_wire::TripWire;
|
||||
use crate::util::TripWire;
|
||||
use crate::fairing::{Fairing, Fairings};
|
||||
use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting};
|
||||
use crate::phase::{Stateful, StateRef, State};
|
||||
|
@ -203,35 +203,31 @@ impl Rocket<Build> {
|
|||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::Config;
|
||||
/// use rocket::config::{Config, Ident};
|
||||
/// # use std::net::Ipv4Addr;
|
||||
/// # use std::path::{Path, PathBuf};
|
||||
/// # type Result = std::result::Result<(), rocket::Error>;
|
||||
///
|
||||
/// let config = Config {
|
||||
/// port: 7777,
|
||||
/// address: Ipv4Addr::new(18, 127, 0, 1).into(),
|
||||
/// ident: Ident::try_new("MyServer").expect("valid ident"),
|
||||
/// temp_dir: "/tmp/config-example".into(),
|
||||
/// ..Config::debug_default()
|
||||
/// };
|
||||
///
|
||||
/// # let _: Result = rocket::async_test(async move {
|
||||
/// let rocket = rocket::custom(&config).ignite().await?;
|
||||
/// assert_eq!(rocket.config().port, 7777);
|
||||
/// assert_eq!(rocket.config().address, Ipv4Addr::new(18, 127, 0, 1));
|
||||
/// assert_eq!(rocket.config().ident.as_str(), Some("MyServer"));
|
||||
/// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example"));
|
||||
///
|
||||
/// // Create a new figment which modifies _some_ keys the existing figment:
|
||||
/// let figment = rocket.figment().clone()
|
||||
/// .merge((Config::PORT, 8888))
|
||||
/// .merge((Config::ADDRESS, "171.64.200.10"));
|
||||
/// .merge((Config::IDENT, "Example"));
|
||||
///
|
||||
/// let rocket = rocket::custom(&config)
|
||||
/// .configure(figment)
|
||||
/// .ignite().await?;
|
||||
///
|
||||
/// assert_eq!(rocket.config().port, 8888);
|
||||
/// assert_eq!(rocket.config().address, Ipv4Addr::new(171, 64, 200, 10));
|
||||
/// assert_eq!(rocket.config().ident.as_str(), Some("Example"));
|
||||
/// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example"));
|
||||
/// # Ok(())
|
||||
/// # });
|
||||
|
@ -664,8 +660,9 @@ impl Rocket<Ignite> {
|
|||
self.shutdown.clone()
|
||||
}
|
||||
|
||||
fn into_orbit(self) -> Rocket<Orbit> {
|
||||
pub(crate) fn into_orbit(self, address: Endpoint) -> Rocket<Orbit> {
|
||||
Rocket(Orbiting {
|
||||
endpoint: address,
|
||||
router: self.0.router,
|
||||
fairings: self.0.fairings,
|
||||
figment: self.0.figment,
|
||||
|
@ -675,28 +672,24 @@ impl Rocket<Ignite> {
|
|||
})
|
||||
}
|
||||
|
||||
async fn _local_launch(self) -> Rocket<Orbit> {
|
||||
let rocket = self.into_orbit();
|
||||
rocket.fairings.handle_liftoff(&rocket).await;
|
||||
launch_info!("{}{}", "🚀 ".emoji(), "Rocket has launched locally".primary().bold());
|
||||
async fn _local_launch(self, addr: Endpoint) -> Rocket<Orbit> {
|
||||
let rocket = self.into_orbit(addr);
|
||||
Rocket::liftoff(&rocket).await;
|
||||
rocket
|
||||
}
|
||||
|
||||
async fn _launch(self) -> Result<Rocket<Ignite>, Error> {
|
||||
self.into_orbit()
|
||||
.default_tcp_http_server(|rkt| Box::pin(async move {
|
||||
rkt.fairings.handle_liftoff(&rkt).await;
|
||||
let config = self.figment().extract::<DefaultListener>()?;
|
||||
either::for_both!(config.base_bindable()?, base => {
|
||||
either::for_both!(config.tls_bindable(base), bindable => {
|
||||
self._launch_on(bindable).await
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
let proto = rkt.config.tls_enabled().then(|| "https").unwrap_or("http");
|
||||
let socket_addr = SocketAddr::new(rkt.config.address, rkt.config.port);
|
||||
let addr = format!("{}://{}", proto, socket_addr);
|
||||
launch_info!("{}{} {}",
|
||||
"🚀 ".emoji(),
|
||||
"Rocket has launched from".bold().primary().linger(),
|
||||
addr.underline());
|
||||
}))
|
||||
.await
|
||||
.map(|rocket| rocket.into_ignite())
|
||||
async fn _launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error> {
|
||||
let listener = bindable.bind().await.map_err(|e| ErrorKind::Bind(Box::new(e)))?;
|
||||
self.serve(listener).await
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -712,6 +705,21 @@ impl Rocket<Orbit> {
|
|||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn liftoff<R: Deref<Target = Self>>(rocket: R) {
|
||||
let rocket = rocket.deref();
|
||||
rocket.fairings.handle_liftoff(rocket).await;
|
||||
|
||||
if !crate::running_within_rocket_async_rt().await {
|
||||
warn!("Rocket is executing inside of a custom runtime.");
|
||||
info_!("Rocket's runtime is enabled via `#[rocket::main]` or `#[launch]`.");
|
||||
info_!("Forced shutdown is disabled. Runtime settings may be suboptimal.");
|
||||
}
|
||||
|
||||
launch_info!("{}{} {}", "🚀 ".emoji(),
|
||||
"Rocket has launched on".bold().primary().linger(),
|
||||
rocket.endpoint().underline());
|
||||
}
|
||||
|
||||
/// Returns the finalized, active configuration. This is guaranteed to
|
||||
/// remain stable after [`Rocket::ignite()`], through ignition and into
|
||||
/// orbit.
|
||||
|
@ -734,6 +742,10 @@ impl Rocket<Orbit> {
|
|||
&self.config
|
||||
}
|
||||
|
||||
pub fn endpoint(&self) -> &Endpoint {
|
||||
&self.endpoint
|
||||
}
|
||||
|
||||
/// Returns a handle which can be used to trigger a shutdown and detect a
|
||||
/// triggered shutdown.
|
||||
///
|
||||
|
@ -867,10 +879,10 @@ impl<P: Phase> Rocket<P> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn local_launch(self) -> Result<Rocket<Orbit>, Error> {
|
||||
pub(crate) async fn local_launch(self, l: Endpoint) -> Result<Rocket<Orbit>, Error> {
|
||||
let rocket = match self.0.into_state() {
|
||||
State::Build(s) => Rocket::from(s).ignite().await?._local_launch().await,
|
||||
State::Ignite(s) => Rocket::from(s)._local_launch().await,
|
||||
State::Build(s) => Rocket::from(s).ignite().await?._local_launch(l).await,
|
||||
State::Ignite(s) => Rocket::from(s)._local_launch(l).await,
|
||||
State::Orbit(s) => Rocket::from(s)
|
||||
};
|
||||
|
||||
|
@ -928,6 +940,14 @@ impl<P: Phase> Rocket<P> {
|
|||
State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error> {
|
||||
match self.0.into_state() {
|
||||
State::Build(s) => Rocket::from(s).ignite().await?._launch_on(bindable).await,
|
||||
State::Ignite(s) => Rocket::from(s)._launch_on(bindable).await,
|
||||
State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
|
|
|
@ -167,7 +167,6 @@ impl<F: Clone + Sync + Send + 'static> Handler for F
|
|||
}
|
||||
}
|
||||
|
||||
// FIXME!
|
||||
impl<'r, 'o: 'r> Outcome<'o> {
|
||||
/// Return the `Outcome` of response to `req` from `responder`.
|
||||
///
|
||||
|
|
|
@ -1,540 +1,142 @@
|
|||
use std::io;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::pin::Pin;
|
||||
|
||||
use yansi::Paint;
|
||||
use tokio::sync::oneshot;
|
||||
use hyper::service::service_fn;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
use futures::{Future, TryFutureExt, future::{select, Either::*}};
|
||||
use tokio::time::sleep;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::future::{FutureExt, Future, BoxFuture};
|
||||
|
||||
use crate::{route, Rocket, Orbit, Request, Response, Data, Config};
|
||||
use crate::form::Form;
|
||||
use crate::outcome::Outcome;
|
||||
use crate::error::{Error, ErrorKind};
|
||||
use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo};
|
||||
use crate::{Request, Rocket, Orbit, Data, Ignite};
|
||||
use crate::request::ConnectionMeta;
|
||||
use crate::data::IoHandler;
|
||||
|
||||
use crate::http::{hyper, uncased, Method, Status, Header};
|
||||
use crate::http::private::{TcpListener, Listener, Connection, Incoming};
|
||||
|
||||
// A token returned to force the execution of one method before another.
|
||||
pub(crate) struct RequestToken;
|
||||
|
||||
async fn handle<Fut, T, F>(name: Option<&str>, run: F) -> Option<T>
|
||||
where F: FnOnce() -> Fut, Fut: Future<Output = T>,
|
||||
{
|
||||
use std::panic::AssertUnwindSafe;
|
||||
|
||||
macro_rules! panic_info {
|
||||
($name:expr, $e:expr) => {{
|
||||
match $name {
|
||||
Some(name) => error_!("Handler {} panicked.", name.primary()),
|
||||
None => error_!("A handler panicked.")
|
||||
};
|
||||
|
||||
info_!("This is an application bug.");
|
||||
info_!("A panic in Rust must be treated as an exceptional event.");
|
||||
info_!("Panicking is not a suitable error handling mechanism.");
|
||||
info_!("Unwinding, the result of a panic, is an expensive operation.");
|
||||
info_!("Panics will degrade application performance.");
|
||||
info_!("Instead of panicking, return `Option` and/or `Result`.");
|
||||
info_!("Values of either type can be returned directly from handlers.");
|
||||
warn_!("A panic is treated as an internal server error.");
|
||||
$e
|
||||
}}
|
||||
}
|
||||
|
||||
let run = AssertUnwindSafe(run);
|
||||
let fut = std::panic::catch_unwind(move || run())
|
||||
.map_err(|e| panic_info!(name, e))
|
||||
.ok()?;
|
||||
|
||||
AssertUnwindSafe(fut)
|
||||
.catch_unwind()
|
||||
.await
|
||||
.map_err(|e| panic_info!(name, e))
|
||||
.ok()
|
||||
}
|
||||
|
||||
// This function tries to hide all of the Hyper-ness from Rocket. It essentially
|
||||
// converts Hyper types into Rocket types, then calls the `dispatch` function,
|
||||
// which knows nothing about Hyper. Because responding depends on the
|
||||
// `HyperResponse` type, this function does the actual response processing.
|
||||
async fn hyper_service_fn(
|
||||
rocket: Arc<Rocket<Orbit>>,
|
||||
conn: ConnectionMeta,
|
||||
mut hyp_req: hyper::Request<hyper::Body>,
|
||||
) -> Result<hyper::Response<hyper::Body>, io::Error> {
|
||||
// This future must return a hyper::Response, but the response body might
|
||||
// borrow from the request. Instead, write the body in another future that
|
||||
// sends the response metadata (and a body channel) prior.
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
#[cfg(not(broken_fmt))]
|
||||
debug!("received request: {:#?}", hyp_req);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// We move the request next, so get the upgrade future now.
|
||||
let pending_upgrade = hyper::upgrade::on(&mut hyp_req);
|
||||
|
||||
// Convert a Hyper request into a Rocket request.
|
||||
let (h_parts, mut h_body) = hyp_req.into_parts();
|
||||
match Request::from_hyp(&rocket, &h_parts, Some(conn)) {
|
||||
Ok(mut req) => {
|
||||
// Convert into Rocket `Data`, dispatch request, write response.
|
||||
let mut data = Data::from(&mut h_body);
|
||||
let token = rocket.preprocess_request(&mut req, &mut data).await;
|
||||
let mut response = rocket.dispatch(token, &req, data).await;
|
||||
let upgrade = response.take_upgrade(req.headers().get("upgrade"));
|
||||
if let Ok(Some((proto, handler))) = upgrade {
|
||||
rocket.handle_upgrade(response, proto, handler, pending_upgrade, tx).await;
|
||||
} else {
|
||||
if upgrade.is_err() {
|
||||
warn_!("Request wants upgrade but no I/O handler matched.");
|
||||
info_!("Request is not being upgraded.");
|
||||
}
|
||||
|
||||
rocket.send_response(response, tx).await;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("Bad incoming HTTP request.");
|
||||
e.errors.iter().for_each(|e| warn_!("Error: {}.", e));
|
||||
warn_!("Dispatching salvaged request to catcher: {}.", e.request);
|
||||
|
||||
let response = rocket.handle_error(Status::BadRequest, &e.request).await;
|
||||
rocket.send_response(response, tx).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Receive the response written to `tx` by the task above.
|
||||
rx.await.map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))
|
||||
}
|
||||
use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler};
|
||||
use crate::listener::{Listener, CancellableExt, BouncedExt};
|
||||
use crate::error::{Error, ErrorKind};
|
||||
use crate::data::IoStream;
|
||||
use crate::util::ReaderStream;
|
||||
use crate::http::Status;
|
||||
|
||||
impl Rocket<Orbit> {
|
||||
/// Wrapper around `_send_response` to log a success or error.
|
||||
#[inline]
|
||||
async fn send_response(
|
||||
&self,
|
||||
response: Response<'_>,
|
||||
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
|
||||
) {
|
||||
let remote_hungup = |e: &io::Error| match e.kind() {
|
||||
| io::ErrorKind::BrokenPipe
|
||||
| io::ErrorKind::ConnectionReset
|
||||
| io::ErrorKind::ConnectionAborted => true,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
match self._send_response(response, tx).await {
|
||||
Ok(()) => info_!("{}", "Response succeeded.".green()),
|
||||
Err(e) if remote_hungup(&e) => warn_!("Remote left: {}.", e),
|
||||
Err(e) => warn_!("Failed to write response: {}.", e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to create a hyper response from `response` and send it to `tx`.
|
||||
#[inline]
|
||||
async fn _send_response(
|
||||
&self,
|
||||
mut response: Response<'_>,
|
||||
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
|
||||
) -> io::Result<()> {
|
||||
let mut hyp_res = hyper::Response::builder();
|
||||
|
||||
hyp_res = hyp_res.status(response.status().code);
|
||||
for header in response.headers().iter() {
|
||||
let name = header.name.as_str();
|
||||
let value = header.value.as_bytes();
|
||||
hyp_res = hyp_res.header(name, value);
|
||||
}
|
||||
|
||||
let body = response.body_mut();
|
||||
if let Some(n) = body.size().await {
|
||||
hyp_res = hyp_res.header(hyper::header::CONTENT_LENGTH, n);
|
||||
}
|
||||
|
||||
let (mut sender, hyp_body) = hyper::Body::channel();
|
||||
let hyp_response = hyp_res.body(hyp_body)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
#[cfg(not(broken_fmt))]
|
||||
debug!("sending response: {:#?}", hyp_response);
|
||||
|
||||
tx.send(hyp_response).map_err(|_| {
|
||||
let msg = "client disconnect before response started";
|
||||
io::Error::new(io::ErrorKind::BrokenPipe, msg)
|
||||
})?;
|
||||
|
||||
let max_chunk_size = body.max_chunk_size();
|
||||
let mut stream = body.into_bytes_stream(max_chunk_size);
|
||||
while let Some(next) = stream.next().await {
|
||||
sender.send_data(next?).await
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_upgrade<'r>(
|
||||
&self,
|
||||
mut response: Response<'r>,
|
||||
proto: uncased::Uncased<'r>,
|
||||
io_handler: Pin<Box<dyn IoHandler + 'r>>,
|
||||
pending_upgrade: hyper::upgrade::OnUpgrade,
|
||||
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
|
||||
) {
|
||||
info_!("Upgrading connection to {}.", Paint::white(&proto).bold());
|
||||
response.set_status(Status::SwitchingProtocols);
|
||||
response.set_raw_header("Connection", "Upgrade");
|
||||
response.set_raw_header("Upgrade", proto.clone().into_cow());
|
||||
self.send_response(response, tx).await;
|
||||
|
||||
match pending_upgrade.await {
|
||||
Ok(io_stream) => {
|
||||
info_!("Upgrade successful.");
|
||||
if let Err(e) = io_handler.io(io_stream.into()).await {
|
||||
if e.kind() == io::ErrorKind::BrokenPipe {
|
||||
warn!("Upgraded {} I/O handler was closed.", proto);
|
||||
} else {
|
||||
error!("Upgraded {} I/O handler failed: {}", proto, e);
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("Response indicated upgrade, but upgrade failed.");
|
||||
warn_!("Upgrade error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Preprocess the request for Rocket things. Currently, this means:
|
||||
///
|
||||
/// * Rewriting the method in the request if _method form field exists.
|
||||
/// * Run the request fairings.
|
||||
///
|
||||
/// Keep this in-sync with derive_form when preprocessing form fields.
|
||||
pub(crate) async fn preprocess_request(
|
||||
&self,
|
||||
req: &mut Request<'_>,
|
||||
data: &mut Data<'_>
|
||||
) -> RequestToken {
|
||||
// Check if this is a form and if the form contains the special _method
|
||||
// field which we use to reinterpret the request's method.
|
||||
let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
|
||||
let peek_buffer = data.peek(max_len).await;
|
||||
let is_form = req.content_type().map_or(false, |ct| ct.is_form());
|
||||
|
||||
if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len {
|
||||
let method = std::str::from_utf8(peek_buffer).ok()
|
||||
.and_then(|raw_form| Form::values(raw_form).next())
|
||||
.filter(|field| field.name == "_method")
|
||||
.and_then(|field| field.value.parse().ok());
|
||||
|
||||
if let Some(method) = method {
|
||||
req._set_method(method);
|
||||
}
|
||||
}
|
||||
|
||||
// Run request fairings.
|
||||
self.fairings.handle_request(req, data).await;
|
||||
|
||||
RequestToken
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) async fn dispatch<'s, 'r: 's>(
|
||||
&'s self,
|
||||
_token: RequestToken,
|
||||
request: &'r Request<'s>,
|
||||
data: Data<'r>
|
||||
) -> Response<'r> {
|
||||
info!("{}:", request);
|
||||
|
||||
// Remember if the request is `HEAD` for later body stripping.
|
||||
let was_head_request = request.method() == Method::Head;
|
||||
|
||||
// Route the request and run the user's handlers.
|
||||
let mut response = self.route_and_process(request, data).await;
|
||||
|
||||
// Add a default 'Server' header if it isn't already there.
|
||||
// TODO: If removing Hyper, write out `Date` header too.
|
||||
if let Some(ident) = request.rocket().config.ident.as_str() {
|
||||
if !response.headers().contains("Server") {
|
||||
response.set_header(Header::new("Server", ident));
|
||||
}
|
||||
}
|
||||
|
||||
// Run the response fairings.
|
||||
self.fairings.handle_response(request, &mut response).await;
|
||||
|
||||
// Strip the body if this is a `HEAD` request.
|
||||
if was_head_request {
|
||||
response.strip_body();
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
async fn route_and_process<'s, 'r: 's>(
|
||||
&'s self,
|
||||
request: &'r Request<'s>,
|
||||
data: Data<'r>
|
||||
) -> Response<'r> {
|
||||
let mut response = match self.route(request, data).await {
|
||||
Outcome::Success(response) => response,
|
||||
Outcome::Forward((data, _)) if request.method() == Method::Head => {
|
||||
info_!("Autohandling {} request.", "HEAD".primary().bold());
|
||||
|
||||
// Dispatch the request again with Method `GET`.
|
||||
request._set_method(Method::Get);
|
||||
match self.route(request, data).await {
|
||||
Outcome::Success(response) => response,
|
||||
Outcome::Error(status) => self.handle_error(status, request).await,
|
||||
Outcome::Forward((_, status)) => self.handle_error(status, request).await,
|
||||
}
|
||||
}
|
||||
Outcome::Forward((_, status)) => self.handle_error(status, request).await,
|
||||
Outcome::Error(status) => self.handle_error(status, request).await,
|
||||
};
|
||||
|
||||
// Set the cookies. Note that error responses will only include cookies
|
||||
// set by the error handler. See `handle_error` for more.
|
||||
let delta_jar = request.cookies().take_delta_jar();
|
||||
for cookie in delta_jar.delta() {
|
||||
response.adjoin_header(cookie);
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
/// Tries to find a `Responder` for a given `request`. It does this by
|
||||
/// routing the request and calling the handler for each matching route
|
||||
/// until one of the handlers returns success or error, or there are no
|
||||
/// additional routes to try (forward). The corresponding outcome for each
|
||||
/// condition is returned.
|
||||
#[inline]
|
||||
async fn route<'s, 'r: 's>(
|
||||
&'s self,
|
||||
request: &'r Request<'s>,
|
||||
mut data: Data<'r>,
|
||||
) -> route::Outcome<'r> {
|
||||
// Go through all matching routes until we fail or succeed or run out of
|
||||
// routes to try, in which case we forward with the last status.
|
||||
let mut status = Status::NotFound;
|
||||
for route in self.router.route(request) {
|
||||
// Retrieve and set the requests parameters.
|
||||
info_!("Matched: {}", route);
|
||||
request.set_route(route);
|
||||
|
||||
let name = route.name.as_deref();
|
||||
let outcome = handle(name, || route.handler.handle(request, data)).await
|
||||
.unwrap_or(Outcome::Error(Status::InternalServerError));
|
||||
|
||||
// Check if the request processing completed (Some) or if the
|
||||
// request needs to be forwarded. If it does, continue the loop
|
||||
// (None) to try again.
|
||||
info_!("{}", outcome.log_display());
|
||||
match outcome {
|
||||
o@Outcome::Success(_) | o@Outcome::Error(_) => return o,
|
||||
Outcome::Forward(forwarded) => (data, status) = forwarded,
|
||||
}
|
||||
}
|
||||
|
||||
error_!("No matching routes for {}.", request);
|
||||
Outcome::Forward((data, status))
|
||||
}
|
||||
|
||||
/// Invokes the handler with `req` for catcher with status `status`.
|
||||
///
|
||||
/// In order of preference, invoked handler is:
|
||||
/// * the user's registered handler for `status`
|
||||
/// * the user's registered `default` handler
|
||||
/// * Rocket's default handler for `status`
|
||||
///
|
||||
/// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))`
|
||||
/// if the handler ran to completion but failed. Returns `Ok(None)` if the
|
||||
/// handler panicked while executing.
|
||||
async fn invoke_catcher<'s, 'r: 's>(
|
||||
&'s self,
|
||||
status: Status,
|
||||
req: &'r Request<'s>
|
||||
) -> Result<Response<'r>, Option<Status>> {
|
||||
// For now, we reset the delta state to prevent any modifications
|
||||
// from earlier, unsuccessful paths from being reflected in error
|
||||
// response. We may wish to relax this in the future.
|
||||
req.cookies().reset_delta();
|
||||
|
||||
if let Some(catcher) = self.router.catch(status, req) {
|
||||
warn_!("Responding with registered {} catcher.", catcher);
|
||||
let name = catcher.name.as_deref();
|
||||
handle(name, || catcher.handler.handle(status, req)).await
|
||||
.map(|result| result.map_err(Some))
|
||||
.unwrap_or_else(|| Err(None))
|
||||
} else {
|
||||
let code = status.code.blue().bold();
|
||||
warn_!("No {} catcher registered. Using Rocket default.", code);
|
||||
Ok(crate::catcher::default_handler(status, req))
|
||||
}
|
||||
}
|
||||
|
||||
// Invokes the catcher for `status`. Returns the response on success.
|
||||
//
|
||||
// On catcher error, the 500 error catcher is attempted. If _that_ errors,
|
||||
// the (infallible) default 500 error cather is used.
|
||||
pub(crate) async fn handle_error<'s, 'r: 's>(
|
||||
&'s self,
|
||||
mut status: Status,
|
||||
req: &'r Request<'s>
|
||||
) -> Response<'r> {
|
||||
// Dispatch to the `status` catcher.
|
||||
if let Ok(r) = self.invoke_catcher(status, req).await {
|
||||
return r;
|
||||
}
|
||||
|
||||
// If it fails and it's not a 500, try the 500 catcher.
|
||||
if status != Status::InternalServerError {
|
||||
error_!("Catcher failed. Attempting 500 error catcher.");
|
||||
status = Status::InternalServerError;
|
||||
if let Ok(r) = self.invoke_catcher(status, req).await {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
// If it failed again or if it was already a 500, use Rocket's default.
|
||||
error_!("{} catcher failed. Using Rocket default 500.", status.code);
|
||||
crate::catcher::default_handler(Status::InternalServerError, req)
|
||||
}
|
||||
|
||||
pub(crate) async fn default_tcp_http_server<C>(mut self, ready: C) -> Result<Self, Error>
|
||||
where C: for<'a> Fn(&'a Self) -> BoxFuture<'a, ()>
|
||||
{
|
||||
use std::net::ToSocketAddrs;
|
||||
|
||||
// Determine the address we're going to serve on.
|
||||
let addr = format!("{}:{}", self.config.address, self.config.port);
|
||||
let mut addr = addr.to_socket_addrs()
|
||||
.map(|mut addrs| addrs.next().expect(">= 1 socket addr"))
|
||||
.map_err(|e| Error::new(ErrorKind::Io(e)))?;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
if self.config.tls_enabled() {
|
||||
if let Some(ref config) = self.config.tls {
|
||||
use crate::http::tls::TlsListener;
|
||||
|
||||
let conf = config.to_native_config().map_err(ErrorKind::Io)?;
|
||||
let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::TlsBind)?;
|
||||
addr = l.local_addr().unwrap_or(addr);
|
||||
self.config.address = addr.ip();
|
||||
self.config.port = addr.port();
|
||||
ready(&mut self).await;
|
||||
return self.http_server(l).await;
|
||||
}
|
||||
}
|
||||
|
||||
let l = TcpListener::bind(addr).await.map_err(ErrorKind::Bind)?;
|
||||
addr = l.local_addr().unwrap_or(addr);
|
||||
self.config.address = addr.ip();
|
||||
self.config.port = addr.port();
|
||||
ready(&mut self).await;
|
||||
self.http_server(l).await
|
||||
}
|
||||
|
||||
// TODO.async: Solidify the Listener APIs and make this function public
|
||||
pub(crate) async fn http_server<L>(self, listener: L) -> Result<Self, Error>
|
||||
where L: Listener + Send, <L as Listener>::Connection: Send + Unpin + 'static
|
||||
{
|
||||
// Emit a warning if we're not running inside of Rocket's async runtime.
|
||||
if self.config.profile == Config::DEBUG_PROFILE {
|
||||
tokio::task::spawn_blocking(|| {
|
||||
let this = std::thread::current();
|
||||
if !this.name().map_or(false, |s| s.starts_with("rocket-worker")) {
|
||||
warn!("Rocket is executing inside of a custom runtime.");
|
||||
info_!("Rocket's runtime is enabled via `#[rocket::main]` or `#[launch]`.");
|
||||
info_!("Forced shutdown is disabled. Runtime settings may be suboptimal.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Set up cancellable I/O from the given listener. Shutdown occurs when
|
||||
// `Shutdown` (`TripWire`) resolves. This can occur directly through a
|
||||
// notification or indirectly through an external signal which, when
|
||||
// received, results in triggering the notify.
|
||||
let shutdown = self.shutdown();
|
||||
let sig_stream = self.config.shutdown.signal_stream();
|
||||
let grace = self.config.shutdown.grace as u64;
|
||||
let mercy = self.config.shutdown.mercy as u64;
|
||||
|
||||
// Start a task that listens for external signals and notifies shutdown.
|
||||
if let Some(mut stream) = sig_stream {
|
||||
let shutdown = shutdown.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(sig) = stream.next().await {
|
||||
if shutdown.0.tripped() {
|
||||
warn!("Received {}. Shutdown already in progress.", sig);
|
||||
} else {
|
||||
warn!("Received {}. Requesting shutdown.", sig);
|
||||
}
|
||||
|
||||
shutdown.0.trip();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Save the keep-alive value for later use; we're about to move `self`.
|
||||
let keep_alive = self.config.keep_alive;
|
||||
|
||||
// Create the Hyper `Service`.
|
||||
let rocket = Arc::new(self);
|
||||
let service_fn = |conn: &CancellableIo<_, L::Connection>| {
|
||||
let rocket = rocket.clone();
|
||||
let connection = ConnectionMeta {
|
||||
remote: conn.peer_address(),
|
||||
client_certificates: conn.peer_certificates(),
|
||||
};
|
||||
|
||||
async move {
|
||||
Ok::<_, std::convert::Infallible>(hyper::service::service_fn(move |req| {
|
||||
hyper_service_fn(rocket.clone(), connection.clone(), req)
|
||||
}))
|
||||
}
|
||||
};
|
||||
|
||||
// NOTE: `hyper` uses `tokio::spawn()` as the default executor.
|
||||
let listener = CancellableListener::new(shutdown.clone(), listener, grace, mercy);
|
||||
let builder = hyper::server::Server::builder(Incoming::new(listener).nodelay(true));
|
||||
|
||||
#[cfg(feature = "http2")]
|
||||
let builder = builder.http2_keep_alive_interval(match keep_alive {
|
||||
0 => None,
|
||||
n => Some(Duration::from_secs(n as u64))
|
||||
async fn service(
|
||||
self: Arc<Self>,
|
||||
mut req: hyper::Request<hyper::body::Incoming>,
|
||||
connection: ConnectionMeta,
|
||||
) -> Result<hyper::Response<ReaderStream<ErasedResponse>>, http::Error> {
|
||||
let upgrade = hyper::upgrade::on(&mut req);
|
||||
let (parts, incoming) = req.into_parts();
|
||||
let request = ErasedRequest::new(self, parts, |rocket, parts| {
|
||||
Request::from_hyp(rocket, parts, connection).unwrap_or_else(|e| e)
|
||||
});
|
||||
|
||||
let server = builder
|
||||
.http1_keepalive(keep_alive != 0)
|
||||
.http1_preserve_header_case(true)
|
||||
.serve(hyper::service::make_service_fn(service_fn))
|
||||
.with_graceful_shutdown(shutdown.clone());
|
||||
let mut response = request.into_response(
|
||||
incoming,
|
||||
|incoming| Data::from(incoming),
|
||||
|rocket, request, data| Box::pin(rocket.preprocess(request, data)),
|
||||
|token, rocket, request, data| Box::pin(async move {
|
||||
if !request.errors.is_empty() {
|
||||
return rocket.dispatch_error(Status::BadRequest, request).await;
|
||||
}
|
||||
|
||||
// This deserves some explanation.
|
||||
//
|
||||
// This is largely to deal with Hyper's dreadful and largely nonexistent
|
||||
// handling of shutdown, in general, nevermind graceful.
|
||||
//
|
||||
// When Hyper receives a "graceful shutdown" request, it stops accepting
|
||||
// new requests. That's it. It continues to process existing requests
|
||||
// and outgoing responses forever and never cancels them. As a result,
|
||||
// Rocket must take it upon itself to cancel any existing I/O.
|
||||
//
|
||||
// To do so, Rocket wraps all connections in a `CancellableIo` struct,
|
||||
// an internal structure that gracefully closes I/O when it receives a
|
||||
// signal. That signal is the `shutdown` future. When the future
|
||||
// resolves, `CancellableIo` begins to terminate in grace, mercy, and
|
||||
// finally force close phases. Since all connections are wrapped in
|
||||
let mut response = rocket.dispatch(token, request, data).await;
|
||||
response.body_mut().size().await;
|
||||
response
|
||||
})
|
||||
).await;
|
||||
|
||||
let io_handler = response.to_io_handler(Rocket::extract_io_handler);
|
||||
if let Some(handler) = io_handler {
|
||||
let upgrade = upgrade.map_ok(IoStream::from).map_err(io::Error::other);
|
||||
tokio::task::spawn(io_handler_task(upgrade, handler));
|
||||
}
|
||||
|
||||
let mut builder = hyper::Response::builder();
|
||||
builder = builder.status(response.inner().status().code);
|
||||
for header in response.inner().headers().iter() {
|
||||
builder = builder.header(header.name().as_str(), header.value());
|
||||
}
|
||||
|
||||
if let Some(size) = response.inner().body().preset_size() {
|
||||
builder = builder.header("Content-Length", size);
|
||||
}
|
||||
|
||||
let chunk_size = response.inner().body().max_chunk_size();
|
||||
builder.body(ReaderStream::with_capacity(response, chunk_size))
|
||||
}
|
||||
}
|
||||
|
||||
async fn io_handler_task<S>(stream: S, mut handler: ErasedIoHandler)
|
||||
where S: Future<Output = io::Result<IoStream>>
|
||||
{
|
||||
let stream = match stream.await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => return warn_!("Upgrade failed: {e}"),
|
||||
};
|
||||
|
||||
info_!("Upgrade succeeded.");
|
||||
if let Err(e) = handler.take().io(stream).await {
|
||||
match e.kind() {
|
||||
io::ErrorKind::BrokenPipe => warn!("Upgrade I/O handler was closed."),
|
||||
e => error!("Upgrade I/O handler failed: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Rocket<Ignite> {
|
||||
pub(crate) async fn serve<L>(self, listener: L) -> Result<Self, crate::Error>
|
||||
where L: Listener + 'static
|
||||
{
|
||||
let mut builder = Builder::new(TokioExecutor::new());
|
||||
let keep_alive = Duration::from_secs(self.config.keep_alive.into());
|
||||
builder.http1()
|
||||
.half_close(true)
|
||||
.timer(TokioTimer::new())
|
||||
.keep_alive(keep_alive > Duration::ZERO)
|
||||
.preserve_header_case(true)
|
||||
.header_read_timeout(Duration::from_secs(15));
|
||||
|
||||
#[cfg(feature = "http2")] {
|
||||
builder.http2().timer(TokioTimer::new());
|
||||
if keep_alive > Duration::ZERO {
|
||||
builder.http2()
|
||||
.timer(TokioTimer::new())
|
||||
.keep_alive_interval(keep_alive / 4)
|
||||
.keep_alive_timeout(keep_alive);
|
||||
}
|
||||
}
|
||||
|
||||
let listener = listener.bounced().cancellable(self.shutdown(), &self.config.shutdown);
|
||||
let rocket = Arc::new(self.into_orbit(listener.socket_addr()?));
|
||||
let _ = tokio::spawn(Rocket::liftoff(rocket.clone())).await;
|
||||
|
||||
let (server, listener) = (Arc::new(builder), Arc::new(listener));
|
||||
while let Some(accept) = listener.accept_next().await {
|
||||
let (listener, rocket, server) = (listener.clone(), rocket.clone(), server.clone());
|
||||
tokio::spawn({
|
||||
let result = async move {
|
||||
let conn = TokioIo::new(listener.connect(accept).await?);
|
||||
let meta = ConnectionMeta::from(conn.inner());
|
||||
let service = service_fn(|req| rocket.clone().service(req, meta.clone()));
|
||||
let serve = pin!(server.serve_connection_with_upgrades(conn, service));
|
||||
match select(serve, rocket.shutdown()).await {
|
||||
Left((result, _)) => result,
|
||||
Right((_, mut conn)) => {
|
||||
conn.as_mut().graceful_shutdown();
|
||||
conn.await
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
result.inspect_err(crate::error::log_server_error)
|
||||
});
|
||||
}
|
||||
|
||||
// Rocket wraps all connections in a `CancellableIo` struct, an internal
|
||||
// structure that gracefully closes I/O when it receives a signal. That
|
||||
// signal is the `shutdown` future. When the future resolves,
|
||||
// `CancellableIo` begins to terminate in grace, mercy, and finally
|
||||
// force close phases. Since all connections are wrapped in
|
||||
// `CancellableIo`, this eventually ends all I/O.
|
||||
//
|
||||
// At that point, unless a user spawned an infinite, stand-alone task
|
||||
|
@ -543,69 +145,35 @@ impl Rocket<Orbit> {
|
|||
// we can return the owned instance of `Rocket`.
|
||||
//
|
||||
// Unfortunately, the Hyper `server` future resolves as soon as it has
|
||||
// finishes processing requests without respect for ongoing responses.
|
||||
// finished processing requests without respect for ongoing responses.
|
||||
// That is, `server` resolves even when there are running tasks that are
|
||||
// generating a response. So, `server` resolving implies little to
|
||||
// nothing about the state of connections. As a result, we depend on the
|
||||
// timing of grace + mercy + some buffer to determine when all
|
||||
// connections should be closed, thus all tasks should be complete, thus
|
||||
// all references to `Arc<Rocket>` should be dropped and we can get a
|
||||
// unique reference.
|
||||
tokio::pin!(server);
|
||||
tokio::select! {
|
||||
biased;
|
||||
// all references to `Arc<Rocket>` should be dropped and we can get back
|
||||
// a unique reference.
|
||||
info!("Shutting down. Waiting for shutdown fairings and pending I/O...");
|
||||
tokio::spawn({
|
||||
let rocket = rocket.clone();
|
||||
async move { rocket.fairings.handle_shutdown(&*rocket).await }
|
||||
});
|
||||
|
||||
_ = shutdown => {
|
||||
// Run shutdown fairings. We compute `sleep()` for grace periods
|
||||
// beforehand to ensure we don't add shutdown fairing completion
|
||||
// time, which is arbitrary, to these periods.
|
||||
info!("Shutdown requested. Waiting for pending I/O...");
|
||||
let grace_timer = sleep(Duration::from_secs(grace));
|
||||
let mercy_timer = sleep(Duration::from_secs(grace + mercy));
|
||||
let shutdown_timer = sleep(Duration::from_secs(grace + mercy + 1));
|
||||
rocket.fairings.handle_shutdown(&*rocket).await;
|
||||
let config = &rocket.config.shutdown;
|
||||
let wait = Duration::from_micros(250);
|
||||
for period in [wait, config.grace(), wait, config.mercy(), wait * 4] {
|
||||
if Arc::strong_count(&rocket) == 1 { break }
|
||||
sleep(period).await;
|
||||
}
|
||||
|
||||
tokio::pin!(grace_timer, mercy_timer, shutdown_timer);
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
result = &mut server => {
|
||||
if let Err(e) = result {
|
||||
warn!("Server failed while shutting down: {}", e);
|
||||
return Err(Error::shutdown(rocket.clone(), e));
|
||||
}
|
||||
|
||||
if Arc::strong_count(&rocket) != 1 { grace_timer.await; }
|
||||
if Arc::strong_count(&rocket) != 1 { mercy_timer.await; }
|
||||
if Arc::strong_count(&rocket) != 1 { shutdown_timer.await; }
|
||||
match Arc::try_unwrap(rocket) {
|
||||
Ok(rocket) => {
|
||||
info!("Graceful shutdown completed successfully.");
|
||||
Ok(rocket)
|
||||
}
|
||||
Err(rocket) => {
|
||||
warn!("Shutdown failed: outstanding background I/O.");
|
||||
Err(Error::shutdown(rocket, None))
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = &mut shutdown_timer => {
|
||||
warn!("Shutdown failed: server executing after timeouts.");
|
||||
return Err(Error::shutdown(rocket.clone(), None));
|
||||
},
|
||||
}
|
||||
match Arc::try_unwrap(rocket) {
|
||||
Ok(rocket) => {
|
||||
info!("Graceful shutdown completed successfully.");
|
||||
Ok(rocket.into_ignite())
|
||||
}
|
||||
result = &mut server => {
|
||||
match result {
|
||||
Ok(()) => {
|
||||
info!("Server shutdown nominally.");
|
||||
Ok(Arc::try_unwrap(rocket).map_err(|r| Error::shutdown(r, None))?)
|
||||
}
|
||||
Err(e) => {
|
||||
info!("Server failed prior to shutdown: {}:", e);
|
||||
Err(Error::shutdown(rocket.clone(), e))
|
||||
}
|
||||
}
|
||||
Err(rocket) => {
|
||||
warn!("Shutdown failed: outstanding background I/O.");
|
||||
Err(Error::new(ErrorKind::Shutdown(rocket)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -198,7 +198,7 @@ impl Fairing for Shield {
|
|||
}
|
||||
|
||||
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
|
||||
let force_hsts = rocket.config().tls_enabled()
|
||||
let force_hsts = rocket.endpoint().is_tls()
|
||||
&& rocket.figment().profile() != Config::DEBUG_PROFILE
|
||||
&& !self.is_enabled::<Hsts>();
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::pin::Pin;
|
|||
use futures::FutureExt;
|
||||
|
||||
use crate::request::{FromRequest, Outcome, Request};
|
||||
use crate::trip_wire::TripWire;
|
||||
use crate::util::TripWire;
|
||||
|
||||
/// A request guard and future for graceful shutdown.
|
||||
///
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -11,6 +11,7 @@ pub enum KeyError {
|
|||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Io(std::io::Error),
|
||||
Bind(Box<dyn std::error::Error + Send + 'static>),
|
||||
Tls(rustls::Error),
|
||||
Mtls(rustls::server::VerifierBuilderError),
|
||||
CertChain(std::io::Error),
|
||||
|
@ -29,6 +30,7 @@ impl std::fmt::Display for Error {
|
|||
CertChain(e) => write!(f, "failed to process certificate chain: {e}"),
|
||||
PrivKey(e) => write!(f, "failed to process private key: {e}"),
|
||||
CertAuth(e) => write!(f, "failed to process certificate authority: {e}"),
|
||||
Bind(e) => write!(f, "failed to bind to network interface: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -66,6 +68,7 @@ impl std::error::Error for Error {
|
|||
Error::CertChain(e) => Some(e),
|
||||
Error::PrivKey(e) => Some(e),
|
||||
Error::CertAuth(e) => Some(e),
|
||||
Error::Bind(e) => Some(&**e),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
mod error;
|
||||
pub(crate) mod config;
|
||||
pub(crate) mod util;
|
||||
|
||||
pub use error::Result;
|
||||
pub use config::{TlsConfig, CipherSuite};
|
||||
pub use error::Error;
|
|
@ -0,0 +1,52 @@
|
|||
use std::io;
|
||||
use std::task::{Poll, Context};
|
||||
use std::pin::Pin;
|
||||
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
|
||||
pin_project! {
|
||||
/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
|
||||
#[must_use = "streams do nothing unless polled"]
|
||||
pub struct Chain<T, U> {
|
||||
#[pin]
|
||||
first: Option<T>,
|
||||
#[pin]
|
||||
second: U,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Chain<T, U> {
|
||||
pub(crate) fn new(first: T, second: U) -> Self {
|
||||
Self { first: Some(first), second }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
|
||||
/// Gets references to the underlying readers in this `Chain`.
|
||||
pub fn get_ref(&self) -> (Option<&T>, &U) {
|
||||
(self.first.as_ref(), &self.second)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let me = self.as_mut().project();
|
||||
if let Some(first) = me.first.as_pin_mut() {
|
||||
let init_rem = buf.remaining();
|
||||
futures::ready!(first.poll_read(cx, buf))?;
|
||||
if buf.remaining() == init_rem {
|
||||
self.as_mut().project().first.set(None);
|
||||
} else {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
|
||||
let me = self.as_mut().project();
|
||||
me.second.poll_read(cx, buf)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
use std::pin::Pin;
|
||||
use std::task::{Poll, Context};
|
||||
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use futures::stream::Stream;
|
||||
use futures::ready;
|
||||
|
||||
/// Join two streams, `a` and `b`, into a new `Join` stream that returns items
|
||||
/// from both streams, biased to `a`, until `a` finishes. The joined stream
|
||||
/// completes when `a` completes, irrespective of `b`. If `b` stops producing
|
||||
/// values, then the joined stream acts exactly like a fused `a`.
|
||||
///
|
||||
/// Values are biased to those of `a`: if `a` provides a value, it is always
|
||||
/// emitted before a value provided by `b`. In other words, values from `b` are
|
||||
/// emitted when and only when `a` is not producing a value.
|
||||
pub fn join<A: Stream, B: Stream>(a: A, b: B) -> Join<A, B> {
|
||||
Join { a, b: Some(b), done: false, }
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// Stream returned by [`join`].
|
||||
pub struct Join<T, U> {
|
||||
#[pin]
|
||||
a: T,
|
||||
#[pin]
|
||||
b: Option<U>,
|
||||
// Set when `a` returns `None`.
|
||||
done: bool,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Stream for Join<T, U>
|
||||
where T: Stream,
|
||||
U: Stream<Item = T::Item>,
|
||||
{
|
||||
type Item = T::Item;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
|
||||
if self.done {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
let me = self.as_mut().project();
|
||||
match me.a.poll_next(cx) {
|
||||
Poll::Ready(opt) => {
|
||||
*me.done = opt.is_none();
|
||||
Poll::Ready(opt)
|
||||
},
|
||||
Poll::Pending => match me.b.as_pin_mut() {
|
||||
None => Poll::Pending,
|
||||
Some(b) => match ready!(b.poll_next(cx)) {
|
||||
Some(value) => Poll::Ready(Some(value)),
|
||||
None => {
|
||||
self.as_mut().project().b.set(None);
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
let (left_low, left_high) = self.a.size_hint();
|
||||
let (right_low, right_high) = self.b.as_ref()
|
||||
.map(|b| b.size_hint())
|
||||
.unwrap_or_default();
|
||||
|
||||
let low = left_low.saturating_add(right_low);
|
||||
let high = match (left_high, right_high) {
|
||||
(Some(h1), Some(h2)) => h1.checked_add(h2),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
(low, high)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
mod chain;
|
||||
mod tripwire;
|
||||
mod reader_stream;
|
||||
mod join;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub mod unix;
|
||||
|
||||
pub use chain::Chain;
|
||||
pub use tripwire::TripWire;
|
||||
pub use reader_stream::ReaderStream;
|
||||
pub use join::join;
|
|
@ -0,0 +1,124 @@
|
|||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::stream::Stream;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::AsyncRead;
|
||||
|
||||
pin_project! {
|
||||
/// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks.
|
||||
///
|
||||
/// This stream is fused. It performs the inverse operation of
|
||||
/// [`StreamReader`].
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() -> std::io::Result<()> {
|
||||
/// use tokio_stream::StreamExt;
|
||||
/// use tokio_util::io::ReaderStream;
|
||||
///
|
||||
/// // Create a stream of data.
|
||||
/// let data = b"hello, world!";
|
||||
/// let mut stream = ReaderStream::new(&data[..]);
|
||||
///
|
||||
/// // Read all of the chunks into a vector.
|
||||
/// let mut stream_contents = Vec::new();
|
||||
/// while let Some(chunk) = stream.next().await {
|
||||
/// stream_contents.extend_from_slice(&chunk?);
|
||||
/// }
|
||||
///
|
||||
/// // Once the chunks are concatenated, we should have the
|
||||
/// // original data.
|
||||
/// assert_eq!(stream_contents, data);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// [`AsyncRead`]: tokio::io::AsyncRead
|
||||
/// [`StreamReader`]: crate::io::StreamReader
|
||||
/// [`Stream`]: futures_core::Stream
|
||||
#[derive(Debug)]
|
||||
pub struct ReaderStream<R> {
|
||||
// Reader itself.
|
||||
//
|
||||
// This value is `None` if the stream has terminated.
|
||||
#[pin]
|
||||
reader: R,
|
||||
// Working buffer, used to optimize allocations.
|
||||
buf: BytesMut,
|
||||
capacity: usize,
|
||||
done: bool,
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead> ReaderStream<R> {
|
||||
/// Convert an [`AsyncRead`] into a [`Stream`] with item type
|
||||
/// `Result<Bytes, std::io::Error>`,
|
||||
/// with a specific read buffer initial capacity.
|
||||
///
|
||||
/// [`AsyncRead`]: tokio::io::AsyncRead
|
||||
/// [`Stream`]: futures_core::Stream
|
||||
pub fn with_capacity(reader: R, capacity: usize) -> Self {
|
||||
ReaderStream {
|
||||
reader: reader,
|
||||
buf: BytesMut::with_capacity(capacity),
|
||||
capacity,
|
||||
done: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead> Stream for ReaderStream<R> {
|
||||
type Item = std::io::Result<Bytes>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
use tokio_util::io::poll_read_buf;
|
||||
|
||||
let mut this = self.as_mut().project();
|
||||
|
||||
if *this.done {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
if this.buf.capacity() == 0 {
|
||||
this.buf.reserve(*this.capacity);
|
||||
}
|
||||
|
||||
let reader = this.reader;
|
||||
match poll_read_buf(reader, cx, &mut this.buf) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(Err(err)) => {
|
||||
*this.done = true;
|
||||
Poll::Ready(Some(Err(err)))
|
||||
}
|
||||
Poll::Ready(Ok(0)) => {
|
||||
*this.done = true;
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Ready(Ok(_)) => {
|
||||
let chunk = this.buf.split();
|
||||
Poll::Ready(Some(Ok(chunk.freeze())))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
// self.reader.size_hint()
|
||||
// }
|
||||
}
|
||||
|
||||
impl<R: AsyncRead> hyper::body::Body for ReaderStream<R> {
|
||||
type Data = bytes::Bytes;
|
||||
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn poll_frame(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
|
||||
self.poll_next(cx).map_ok(hyper::body::Frame::data)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
use std::io;
|
||||
use std::os::fd::AsRawFd;
|
||||
|
||||
pub fn lock_exlusive_nonblocking<T: AsRawFd>(file: &T) -> io::Result<()> {
|
||||
let raw_fd = file.as_raw_fd();
|
||||
let res = unsafe {
|
||||
libc::flock(raw_fd, libc::LOCK_EX | libc::LOCK_NB)
|
||||
};
|
||||
|
||||
match res {
|
||||
0 => Ok(()),
|
||||
_ => Err(io::Error::last_os_error()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unlock_nonblocking<T: AsRawFd>(file: &T) -> io::Result<()> {
|
||||
let res = unsafe {
|
||||
libc::flock(file.as_raw_fd(), libc::LOCK_UN | libc::LOCK_NB)
|
||||
};
|
||||
|
||||
match res {
|
||||
0 => Ok(()),
|
||||
_ => Err(io::Error::last_os_error()),
|
||||
}
|
||||
}
|
|
@ -1,8 +1,9 @@
|
|||
#![cfg(feature = "tls")]
|
||||
|
||||
use rocket::fs::relative;
|
||||
use rocket::config::{Config, TlsConfig, CipherSuite};
|
||||
use rocket::local::asynchronous::Client;
|
||||
use rocket::tls::{TlsConfig, CipherSuite};
|
||||
use rocket::figment::providers::Serialized;
|
||||
|
||||
#[rocket::async_test]
|
||||
async fn can_launch_tls() {
|
||||
|
@ -15,9 +16,8 @@ async fn can_launch_tls() {
|
|||
CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
]);
|
||||
|
||||
let rocket = rocket::custom(Config { tls: Some(tls), ..Config::debug_default() });
|
||||
let client = Client::debug(rocket).await.unwrap();
|
||||
|
||||
let config = rocket::Config::figment().merge(Serialized::defaults(tls));
|
||||
let client = Client::debug(rocket::custom(config)).await.unwrap();
|
||||
client.rocket().shutdown().notify();
|
||||
client.rocket().shutdown().await;
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::net::{SocketAddr, Ipv4Addr};
|
||||
|
||||
use rocket::config::Config;
|
||||
use rocket::fairing::AdHoc;
|
||||
use rocket::futures::channel::oneshot;
|
||||
|
@ -5,13 +7,13 @@ use rocket::futures::channel::oneshot;
|
|||
#[rocket::async_test]
|
||||
async fn on_ignite_fairing_can_inspect_port() {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let rocket = rocket::custom(Config { port: 0, ..Config::debug_default() })
|
||||
let rocket = rocket::custom(Config::debug_default())
|
||||
.attach(AdHoc::on_liftoff("Send Port -> Channel", move |rocket| {
|
||||
Box::pin(async move {
|
||||
tx.send(rocket.config().port).unwrap();
|
||||
tx.send(rocket.endpoint().tcp().unwrap().port()).unwrap();
|
||||
})
|
||||
}));
|
||||
|
||||
rocket::tokio::spawn(rocket.launch());
|
||||
rocket::tokio::spawn(rocket.launch_on(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))));
|
||||
assert_ne!(rx.await.unwrap(), 0);
|
||||
}
|
||||
|
|
|
@ -155,7 +155,7 @@ fn inner_sentinels_detected() {
|
|||
|
||||
impl<'r, 'o: 'r> response::Responder<'r, 'o> for ResponderSentinel {
|
||||
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
|
||||
todo!()
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -8,19 +8,14 @@ macro_rules! relative {
|
|||
|
||||
#[test]
|
||||
fn tls_config_from_source() {
|
||||
use rocket::config::{Config, TlsConfig};
|
||||
use rocket::figment::Figment;
|
||||
use rocket::tls::TlsConfig;
|
||||
use rocket::figment::{Figment, providers::Serialized};
|
||||
|
||||
let cert_path = relative!("examples/tls/private/cert.pem");
|
||||
let key_path = relative!("examples/tls/private/key.pem");
|
||||
let config = TlsConfig::from_paths(cert_path, key_path);
|
||||
|
||||
let rocket_config = Config {
|
||||
tls: Some(TlsConfig::from_paths(cert_path, key_path)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let config: Config = Figment::from(rocket_config).extract().unwrap();
|
||||
let tls = config.tls.expect("have TLS config");
|
||||
let tls: TlsConfig = Figment::from(Serialized::globals(config)).extract().unwrap();
|
||||
assert_eq!(tls.certs().unwrap_left(), cert_path);
|
||||
assert_eq!(tls.key().unwrap_left(), key_path);
|
||||
}
|
||||
|
|
|
@ -6,15 +6,11 @@ async fn test_config(profile: &str) {
|
|||
let config = rocket.config();
|
||||
match &*profile {
|
||||
"debug" => {
|
||||
assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST);
|
||||
assert_eq!(config.port, 8000);
|
||||
assert_eq!(config.workers, 1);
|
||||
assert_eq!(config.keep_alive, 0);
|
||||
assert_eq!(config.log_level, LogLevel::Normal);
|
||||
}
|
||||
"release" => {
|
||||
assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST);
|
||||
assert_eq!(config.port, 8000);
|
||||
assert_eq!(config.workers, 12);
|
||||
assert_eq!(config.keep_alive, 5);
|
||||
assert_eq!(config.log_level, LogLevel::Critical);
|
||||
|
|
|
@ -74,19 +74,8 @@ fn hello(lang: Option<Lang>, opt: Options<'_>) -> String {
|
|||
|
||||
#[launch]
|
||||
fn rocket() -> _ {
|
||||
use rocket::fairing::AdHoc;
|
||||
|
||||
rocket::build()
|
||||
.mount("/", routes![hello])
|
||||
.mount("/hello", routes![world, mir])
|
||||
.mount("/wave", routes![wave])
|
||||
.attach(AdHoc::on_request("Compatibility Normalizer", |req, _| Box::pin(async move {
|
||||
if !req.uri().is_normalized_nontrailing() {
|
||||
let normal = req.uri().clone().into_normalized_nontrailing();
|
||||
warn!("Incoming request URI was normalized for compatibility.");
|
||||
info_!("{} -> {}", req.uri(), normal);
|
||||
req.set_uri(normal);
|
||||
}
|
||||
})))
|
||||
|
||||
}
|
||||
|
|
|
@ -1,33 +1,38 @@
|
|||
//! Redirect all HTTP requests to HTTPs.
|
||||
|
||||
use std::sync::OnceLock;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use rocket::http::Status;
|
||||
use rocket::log::LogLevel;
|
||||
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite, Config};
|
||||
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite};
|
||||
use rocket::fairing::{Fairing, Info, Kind};
|
||||
use rocket::response::Redirect;
|
||||
|
||||
use yansi::Paint;
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct Redirector(u16);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Redirector {
|
||||
pub listen_port: u16,
|
||||
pub tls_port: OnceLock<u16>,
|
||||
pub struct Config {
|
||||
server: rocket::Config,
|
||||
tls_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl Redirector {
|
||||
pub fn on(port: u16) -> Self {
|
||||
Redirector { listen_port: port, tls_port: OnceLock::new() }
|
||||
Redirector(port)
|
||||
}
|
||||
|
||||
// Route function that gets called on every single request.
|
||||
fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
|
||||
// FIXME: Check the host against a whitelist!
|
||||
let redirector = req.rocket().state::<Self>().expect("managed Self");
|
||||
let config = req.rocket().state::<Config>().expect("managed Self");
|
||||
if let Some(host) = req.host() {
|
||||
let domain = host.domain();
|
||||
let https_uri = match redirector.tls_port.get() {
|
||||
Some(443) | None => format!("https://{domain}{}", req.uri()),
|
||||
Some(port) => format!("https://{domain}:{port}{}", req.uri()),
|
||||
let https_uri = match config.tls_addr.port() {
|
||||
443 => format!("https://{domain}{}", req.uri()),
|
||||
port => format!("https://{domain}:{port}{}", req.uri()),
|
||||
};
|
||||
|
||||
route::Outcome::from(req, Redirect::permanent(https_uri)).pin()
|
||||
|
@ -37,21 +42,12 @@ impl Redirector {
|
|||
}
|
||||
|
||||
// Launch an instance of Rocket than handles redirection on `self.port`.
|
||||
pub async fn try_launch(self, mut config: Config) -> Result<Rocket<Ignite>, Error> {
|
||||
use yansi::Paint;
|
||||
pub async fn try_launch(self, config: Config) -> Result<Rocket<Ignite>, Error> {
|
||||
use rocket::http::Method::*;
|
||||
|
||||
// Determine the port TLS is being served on.
|
||||
let tls_port = self.tls_port.get_or_init(|| config.port);
|
||||
|
||||
// Adjust config for redirector: disable TLS, set port, disable logging.
|
||||
config.tls = None;
|
||||
config.port = self.listen_port;
|
||||
config.log_level = LogLevel::Critical;
|
||||
|
||||
info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta());
|
||||
info_!("redirecting on insecure port {} to TLS port {}",
|
||||
self.listen_port.yellow(), tls_port.green());
|
||||
info_!("redirecting insecure port {} to TLS port {}",
|
||||
self.0.yellow(), config.tls_addr.port().green());
|
||||
|
||||
// Build a vector of routes to `redirect` on `<path..>` for each method.
|
||||
let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch]
|
||||
|
@ -59,10 +55,11 @@ impl Redirector {
|
|||
.map(|m| Route::new(m, "/<path..>", Self::redirect))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
rocket::custom(config)
|
||||
.manage(self)
|
||||
let addr = SocketAddr::new(config.tls_addr.ip(), self.0);
|
||||
rocket::custom(&config.server)
|
||||
.manage(config)
|
||||
.mount("/", redirects)
|
||||
.launch()
|
||||
.launch_on(addr)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
@ -76,8 +73,24 @@ impl Fairing for Redirector {
|
|||
}
|
||||
}
|
||||
|
||||
async fn on_liftoff(&self, rkt: &Rocket<Orbit>) {
|
||||
let (this, shutdown, config) = (self.clone(), rkt.shutdown(), rkt.config().clone());
|
||||
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
|
||||
let Some(tls_addr) = rocket.endpoint().tls().and_then(|tls| tls.tcp()) else {
|
||||
info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta());
|
||||
warn_!("Main instance is not being served over TLS/TCP.");
|
||||
warn_!("Redirector refusing to start.");
|
||||
return;
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
tls_addr,
|
||||
server: rocket::Config {
|
||||
log_level: LogLevel::Critical,
|
||||
..rocket.config().clone()
|
||||
},
|
||||
};
|
||||
|
||||
let this = *self;
|
||||
let shutdown = rocket.shutdown();
|
||||
let _ = rocket::tokio::spawn(async move {
|
||||
if let Err(e) = this.try_launch(config).await {
|
||||
error!("Failed to start HTTP -> HTTPS redirector.");
|
||||
|
|
|
@ -1,11 +1,21 @@
|
|||
use std::fs::{self, File};
|
||||
|
||||
use rocket::http::{CookieJar, Cookie};
|
||||
use rocket::local::blocking::Client;
|
||||
use rocket::fs::relative;
|
||||
|
||||
#[get("/cookie")]
|
||||
fn cookie(jar: &CookieJar<'_>) {
|
||||
jar.add(("k1", "v1"));
|
||||
jar.add_private(("k2", "v2"));
|
||||
|
||||
jar.add(Cookie::build(("k1u", "v1u")).secure(false));
|
||||
jar.add_private(Cookie::build(("k2u", "v2u")).secure(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hello_mutual() {
|
||||
let client = Client::tracked(super::rocket()).unwrap();
|
||||
let client = Client::tracked_secure(super::rocket()).unwrap();
|
||||
let cert_paths = fs::read_dir(relative!("private")).unwrap()
|
||||
.map(|entry| entry.unwrap().path().to_string_lossy().into_owned())
|
||||
.filter(|path| path.ends_with("_cert.pem") && !path.ends_with("ca_cert.pem"));
|
||||
|
@ -23,35 +33,43 @@ fn hello_mutual() {
|
|||
|
||||
#[test]
|
||||
fn secure_cookies() {
|
||||
use rocket::http::{CookieJar, Cookie};
|
||||
let rocket = super::rocket().mount("/", routes![cookie]);
|
||||
let client = Client::tracked_secure(rocket).unwrap();
|
||||
|
||||
#[get("/cookie")]
|
||||
fn cookie(jar: &CookieJar<'_>) {
|
||||
jar.add(("k1", "v1"));
|
||||
jar.add_private(("k2", "v2"));
|
||||
|
||||
jar.add(Cookie::build(("k1u", "v1u")).secure(false));
|
||||
jar.add_private(Cookie::build(("k2u", "v2u")).secure(false));
|
||||
}
|
||||
|
||||
let client = Client::tracked(super::rocket().mount("/", routes![cookie])).unwrap();
|
||||
let response = client.get("/cookie").dispatch();
|
||||
|
||||
let c1 = response.cookies().get("k1").unwrap();
|
||||
assert_eq!(c1.secure(), Some(true));
|
||||
|
||||
let c2 = response.cookies().get_private("k2").unwrap();
|
||||
let c3 = response.cookies().get("k1u").unwrap();
|
||||
let c4 = response.cookies().get_private("k2u").unwrap();
|
||||
|
||||
assert_eq!(c1.secure(), Some(true));
|
||||
assert_eq!(c2.secure(), Some(true));
|
||||
assert_ne!(c3.secure(), Some(true));
|
||||
assert_ne!(c4.secure(), Some(true));
|
||||
}
|
||||
|
||||
let c1 = response.cookies().get("k1u").unwrap();
|
||||
assert_ne!(c1.secure(), Some(true));
|
||||
#[test]
|
||||
fn insecure_cookies() {
|
||||
let rocket = super::rocket().mount("/", routes![cookie]);
|
||||
let client = Client::tracked(rocket).unwrap();
|
||||
|
||||
let c2 = response.cookies().get_private("k2u").unwrap();
|
||||
assert_ne!(c2.secure(), Some(true));
|
||||
let response = client.get("/cookie").dispatch();
|
||||
let c1 = response.cookies().get("k1").unwrap();
|
||||
let c2 = response.cookies().get_private("k2").unwrap();
|
||||
let c3 = response.cookies().get("k1u").unwrap();
|
||||
let c4 = response.cookies().get_private("k2u").unwrap();
|
||||
|
||||
assert_eq!(c1.secure(), None);
|
||||
assert_eq!(c2.secure(), None);
|
||||
assert_eq!(c3.secure(), None);
|
||||
assert_eq!(c4.secure(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hello_world() {
|
||||
use rocket::listener::DefaultListener;
|
||||
use rocket::config::{Config, SecretKey};
|
||||
|
||||
let profiles = [
|
||||
"rsa_sha256",
|
||||
"ecdsa_nistp256_sha256_pkcs8",
|
||||
|
@ -61,11 +79,20 @@ fn hello_world() {
|
|||
"ed25519",
|
||||
];
|
||||
|
||||
// TODO: Testing doesn't actually read keys since we don't do TLS locally.
|
||||
for profile in profiles {
|
||||
let config = rocket::Config::figment().select(profile);
|
||||
let client = Client::tracked(super::rocket().configure(config)).unwrap();
|
||||
let config = Config {
|
||||
secret_key: SecretKey::generate().unwrap(),
|
||||
..Config::debug_default()
|
||||
};
|
||||
|
||||
let figment = Config::figment().merge(config).select(profile);
|
||||
let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap();
|
||||
let response = client.get("/").dispatch();
|
||||
assert_eq!(response.into_string().unwrap(), "Hello, world!");
|
||||
|
||||
let figment = client.rocket().figment();
|
||||
let listener: DefaultListener = figment.extract().unwrap();
|
||||
assert_eq!(figment.profile(), profile);
|
||||
listener.tls.as_ref().unwrap().validate().expect("valid TLS config");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
<div id="log"></div>
|
||||
</body>
|
||||
<script language="javascript" type="text/javascript">
|
||||
var wsUri = "ws://127.0.0.1:8000/echo";
|
||||
var wsUri = "ws://127.0.0.1:8000/echo?raw";
|
||||
var log;
|
||||
|
||||
function init() {
|
||||
|
|
|
@ -20,8 +20,10 @@ fi
|
|||
echo ":::: Generating the docs..."
|
||||
pushd "${PROJECT_ROOT}" > /dev/null 2>&1
|
||||
# Set the crate version and fill in missing doc URLs with docs.rs links.
|
||||
RUSTDOCFLAGS="-Zunstable-options --crate-version ${DOC_VERSION} --extern-html-root-url rocket=https://api.rocket.rs/rocket/" \
|
||||
cargo doc -Zrustdoc-map --no-deps --all-features \
|
||||
RUSTDOCFLAGS="-Z unstable-options \
|
||||
--crate-version ${DOC_VERSION} \
|
||||
--enable-index-page" \
|
||||
cargo doc -Zrustdoc-map --no-deps --all-features \
|
||||
-p rocket \
|
||||
-p rocket_db_pools \
|
||||
-p rocket_sync_db_pools \
|
||||
|
|
|
@ -126,10 +126,11 @@ function test_contrib() {
|
|||
|
||||
function test_core() {
|
||||
FEATURES=(
|
||||
tokio-macros
|
||||
http2
|
||||
secrets
|
||||
tls
|
||||
mtls
|
||||
http2
|
||||
json
|
||||
msgpack
|
||||
uuid
|
||||
|
|
Loading…
Reference in New Issue