Update to hyper 1. Enable custom + unix listeners.

This commit completely rewrites Rocket's HTTP serving. In addition to
significant internal cleanup, this commit introduces the following major
features:

  * Support for custom, external listeners in the `listener` module.

    The new `listener` module contains new `Bindable`, `Listener`, and
    `Connection` traits which enable composable, external
    implementations of connection listeners. Rocket can launch on any
    `Listener`, or anything that can be used to create a listener
    (`Bindable`), via a new `launch_on()` method.

  * Support for Unix domain socket listeners out of the box.

    The default listener backwards compatibly supports listening on Unix
    domain sockets. To do so, configure an `address` of
    `unix:path/to/socket` and optional set `reuse` to `true` (the
    default) or `false` which controls whether Rocket will handle
    creating and deleting the unix domain socket.

In addition to these new features, this commit makes the following major
improvements:

  * Rocket now depends on hyper 1.

  * Rocket no longer depends on hyper to handle connections. This allows
    us to handle more connection failure conditions which results in an
    overall more robust server with fewer dependencies.

  * Logic to work around hyper's inability to reference incoming request
    data in the response results in a 15% performance improvement.

  * `Client`s can be marked secure with `Client::{un}tracked_secure()`,
    allowing Rocket to treat local connections as running under TLS.

  * The `macros` feature of `tokio` is no longer used by Rocket itself.
    Dependencies can take advantage of this reduction in compile-time
    cost by disabling the new default feature `tokio-macros`.

  * A new `TlsConfig::validate()` method allows checking a TLS config.

  * New `TlsConfig::{certs,key}_reader()`,
    `MtlsConfig::ca_certs_reader()` methods return `BufReader`s, which
    allow reading the configured certs and key directly.

  * A new `NamedFile::open_with()` constructor allows specifying
    `OpenOptions`.

These improvements resulted in the following breaking changes:

  * The MSRV is now 1.74.
  * `hyper` is no longer exported from `rocket::http`.
  * `IoHandler::io` takes `Box<Self>` instead of `Pin<Box<Self>>`.
    - Use `Box::into_pin(self)` to recover the previous type.
  * `Response::upgrade()` now returns an `&mut dyn IoHandler`, not
    `Pin<& mut _>`.
  * `Config::{address,port,tls,mtls}` methods have been removed.
    - Use methods on `Rocket::endpoint()` instead.
  * `TlsConfig` was moved to `tls::TlsConfig`.
  * `MutualTls` was renamed and moved to `mtls::MtlsConfig`.
  * `ErrorKind::TlsBind` was removed.
  * The second field of `ErrorKind::Shutdown` was removed.
  * `{Local}Request::{set_}remote()` methods take/return an `Endpoint`.
  * `Client::new()` was removed; it was previously deprecated.

Internally, the following major changes were made:

  * A new `async_bound` attribute macro was introduced to allow setting
    bounds on futures returned by `async fn`s in traits while
    maintaining good docs.

  * All utility functionality was moved to a new `util` module.

Resolves #2671.
Resolves #1070.
This commit is contained in:
Sergio Benitez 2023-12-19 14:32:11 -08:00
parent e9b568d9b2
commit fd294049c7
90 changed files with 3869 additions and 3246 deletions

View File

@ -33,7 +33,6 @@ use crate::result::{Result, Error};
/// ///
/// [`StreamExt`]: rocket::futures::StreamExt /// [`StreamExt`]: rocket::futures::StreamExt
/// [`SinkExt`]: rocket::futures::SinkExt /// [`SinkExt`]: rocket::futures::SinkExt
pub struct DuplexStream(tokio_tungstenite::WebSocketStream<IoStream>); pub struct DuplexStream(tokio_tungstenite::WebSocketStream<IoStream>);
impl DuplexStream { impl DuplexStream {

View File

@ -1,5 +1,4 @@
use std::io; use std::io;
use std::pin::Pin;
use rocket::data::{IoHandler, IoStream}; use rocket::data::{IoHandler, IoStream};
use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream}; use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
@ -37,10 +36,6 @@ pub struct WebSocket {
} }
impl WebSocket { impl WebSocket {
fn new(key: String) -> WebSocket {
WebSocket { config: Config::default(), key }
}
/// Change the default connection configuration to `config`. /// Change the default connection configuration to `config`.
/// ///
/// # Example /// # Example
@ -202,7 +197,9 @@ impl<'r> FromRequest<'r> for WebSocket {
let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13"); let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes())); let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
match key { match key {
Some(key) if is_upgrade && is_ws && is_13 => Outcome::Success(WebSocket::new(key)), Some(key) if is_upgrade && is_ws && is_13 => {
Outcome::Success(WebSocket { key, config: Config::default() })
},
Some(_) | None => Outcome::Forward(Status::BadRequest) Some(_) | None => Outcome::Forward(Status::BadRequest)
} }
} }
@ -232,9 +229,9 @@ impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
#[rocket::async_trait] #[rocket::async_trait]
impl IoHandler for Channel<'_> { impl IoHandler for Channel<'_> {
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> { async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
let channel = Pin::into_inner(self); let stream = DuplexStream::new(io, self.ws.config).await;
let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await; let result = (self.handler)(stream).await;
handle_result(result).map(|_| ()) handle_result(result).map(|_| ())
} }
} }
@ -243,9 +240,9 @@ impl IoHandler for Channel<'_> {
impl<'r, S> IoHandler for MessageStream<'r, S> impl<'r, S> IoHandler for MessageStream<'r, S>
where S: futures::Stream<Item = Result<Message>> + Send + 'r where S: futures::Stream<Item = Result<Message>> + Send + 'r
{ {
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> { async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split(); let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
let stream = (Pin::into_inner(self).handler)(source); let stream = (self.handler)(source);
rocket::tokio::pin!(stream); rocket::tokio::pin!(stream);
while let Some(msg) = stream.next().await { while let Some(msg) = stream.next().await {
let result = match msg { let result = match msg {

View File

@ -0,0 +1,61 @@
use proc_macro2::{TokenStream, Span};
use devise::{Spanned, Result, ext::SpanDiagnosticExt};
use syn::{Token, parse_quote, parse_quote_spanned};
use syn::{TraitItemFn, TypeParamBound, ReturnType, Attribute};
use syn::punctuated::Punctuated;
use syn::parse::Parser;
fn _async_bound(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream
) -> Result<TokenStream> {
let bounds = <Punctuated<TypeParamBound, Token![+]>>::parse_terminated.parse(args)?;
if bounds.is_empty() {
return Ok(input.into());
}
let mut func: TraitItemFn = syn::parse(input)?;
let original: TraitItemFn = func.clone();
if !func.sig.asyncness.is_some() {
let diag = Span::call_site()
.error("attribute can only be applied to async fns")
.span_help(func.sig.span(), "this fn declaration must be `async`");
return Err(diag);
}
let doc: Attribute = parse_quote! {
#[doc = concat!(
"# Future Bounds",
"\n",
"**The `Future` generated by this `async fn` must be `", stringify!(#bounds), "`**."
)]
};
func.sig.asyncness = None;
func.sig.output = match func.sig.output {
ReturnType::Type(arrow, ty) => parse_quote_spanned!(ty.span() =>
#arrow impl ::core::future::Future<Output = #ty> + #bounds
),
default@ReturnType::Default => default
};
Ok(quote! {
#[cfg(all(not(doc), rust_analyzer))]
#original
#[cfg(all(doc, not(rust_analyzer)))]
#doc
#original
#[cfg(not(any(doc, rust_analyzer)))]
#func
})
}
pub fn async_bound(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream
) -> TokenStream {
_async_bound(args, input).unwrap_or_else(|d| d.emit_as_item_tokens())
}

View File

@ -2,3 +2,4 @@ pub mod entry;
pub mod catch; pub mod catch;
pub mod route; pub mod route;
pub mod param; pub mod param;
pub mod async_bound;

View File

@ -331,7 +331,7 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
let internal_uri_macro = internal_uri_macro_decl(&route); let internal_uri_macro = internal_uri_macro_decl(&route);
let responder_outcome = responder_outcome_expr(&route); let responder_outcome = responder_outcome_expr(&route);
let method = route.attr.method; let method = &route.attr.method;
let uri = route.attr.uri.to_string(); let uri = route.attr.uri.to_string();
let rank = Optional(route.attr.rank); let rank = Optional(route.attr.rank);
let format = Optional(route.attr.format.as_ref()); let format = Optional(route.attr.format.as_ref());

View File

@ -13,7 +13,7 @@ pub struct Status(pub http::Status);
#[derive(Debug)] #[derive(Debug)]
pub struct MediaType(pub http::MediaType); pub struct MediaType(pub http::MediaType);
#[derive(Debug, Copy, Clone)] #[derive(Debug, Clone)]
pub struct Method(pub http::Method); pub struct Method(pub http::Method);
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -108,7 +108,7 @@ const VALID_METHODS: &[http::Method] = &[
impl FromMeta for Method { impl FromMeta for Method {
fn from_meta(meta: &MetaItem) -> Result<Self> { fn from_meta(meta: &MetaItem) -> Result<Self> {
let span = meta.value_span(); let span = meta.value_span();
let help_text = format!("method must be one of: {}", VALID_METHODS_STR); let help_text = format!("method must be one of: {VALID_METHODS_STR}");
if let MetaItem::Path(path) = meta { if let MetaItem::Path(path) = meta {
if let Some(ident) = path.last_ident() { if let Some(ident) = path.last_ident() {
@ -131,19 +131,13 @@ impl FromMeta for Method {
impl ToTokens for Method { impl ToTokens for Method {
fn to_tokens(&self, tokens: &mut TokenStream) { fn to_tokens(&self, tokens: &mut TokenStream) {
let method_tokens = match self.0 { let mut chars = self.0.as_str().chars();
http::Method::Get => quote!(::rocket::http::Method::Get), let variant_str = chars.next()
http::Method::Put => quote!(::rocket::http::Method::Put), .map(|c| c.to_ascii_uppercase().to_string() + &chars.as_str().to_lowercase())
http::Method::Post => quote!(::rocket::http::Method::Post), .unwrap_or_default();
http::Method::Delete => quote!(::rocket::http::Method::Delete),
http::Method::Options => quote!(::rocket::http::Method::Options),
http::Method::Head => quote!(::rocket::http::Method::Head),
http::Method::Trace => quote!(::rocket::http::Method::Trace),
http::Method::Connect => quote!(::rocket::http::Method::Connect),
http::Method::Patch => quote!(::rocket::http::Method::Patch),
};
tokens.extend(method_tokens); let variant = syn::Ident::new(&variant_str, Span::call_site());
tokens.extend(quote!(::rocket::http::Method::#variant));
} }
} }

View File

@ -1497,3 +1497,10 @@ pub fn internal_guide_tests(input: TokenStream) -> TokenStream {
pub fn export(input: TokenStream) -> TokenStream { pub fn export(input: TokenStream) -> TokenStream {
emit!(bang::export_internal(input)) emit!(bang::export_internal(input))
} }
/// Private Rocket attribute: `async_bound(Bounds + On + Returned + Future)`.
#[doc(hidden)]
#[proc_macro_attribute]
pub fn async_bound(args: TokenStream, input: TokenStream) -> TokenStream {
emit!(attribute::async_bound::async_bound(args, input))
}

View File

@ -17,43 +17,22 @@ rust-version = "1.64"
[features] [features]
default = [] default = []
tls = ["rustls", "tokio-rustls", "rustls-pemfile"]
mtls = ["tls", "x509-parser"]
http2 = ["hyper/http2"]
private-cookies = ["cookie/private", "cookie/key-expansion"]
serde = ["uncased/with-serde-alloc", "serde_"] serde = ["uncased/with-serde-alloc", "serde_"]
uuid = ["uuid_"] uuid = ["uuid_"]
[dependencies] [dependencies]
smallvec = { version = "1.11", features = ["const_generics", "const_new"] } smallvec = { version = "1.11", features = ["const_generics", "const_new"] }
percent-encoding = "2" percent-encoding = "2"
http = "0.2"
time = { version = "0.3", features = ["formatting", "macros"] } time = { version = "0.3", features = ["formatting", "macros"] }
indexmap = "2" indexmap = "2"
rustls = { version = "0.22", optional = true }
tokio-rustls = { version = "0.25", optional = true }
rustls-pemfile = { version = "2.0.0", optional = true }
tokio = { version = "1.6.1", features = ["net", "sync", "time"] }
log = "0.4"
ref-cast = "1.0" ref-cast = "1.0"
uncased = "0.9.6" uncased = "0.9.10"
either = "1" either = "1"
pear = "0.2.8" pear = "0.2.8"
pin-project-lite = "0.2"
memchr = "2" memchr = "2"
stable-pattern = "0.1" stable-pattern = "0.1"
cookie = { version = "0.18", features = ["percent-encode"] } cookie = { version = "0.18", features = ["percent-encode"] }
state = "0.6" state = "0.6"
futures = { version = "0.3", default-features = false }
[dependencies.x509-parser]
version = "0.13"
optional = true
[dependencies.hyper]
version = "0.14.9"
default-features = false
features = ["http1", "runtime", "server", "stream"]
[dependencies.serde_] [dependencies.serde_]
package = "serde" package = "serde"

View File

@ -745,8 +745,7 @@ impl<'h> HeaderMap<'h> {
/// WARNING: This is unstable! Do not use this method outside of Rocket! /// WARNING: This is unstable! Do not use this method outside of Rocket!
#[doc(hidden)] #[doc(hidden)]
#[inline] #[inline]
pub fn into_iter_raw(self) pub fn into_iter_raw(self) -> impl Iterator<Item=(Uncased<'h>, Vec<Cow<'h, str>>)> {
-> impl Iterator<Item=(Uncased<'h>, Vec<Cow<'h, str>>)> {
self.headers.into_iter() self.headers.into_iter()
} }
} }

View File

@ -1,35 +0,0 @@
//! Re-exported hyper HTTP library types.
//!
//! All types that are re-exported from Hyper reside inside of this module.
//! These types will, with certainty, be removed with time, but they reside here
//! while necessary.
pub use hyper::{Method, Error, Body, Uri, Version, Request, Response};
pub use hyper::{body, server, service, upgrade};
pub use http::{HeaderValue, request, uri};
/// Reexported Hyper HTTP header types.
pub mod header {
macro_rules! import_http_headers {
($($name:ident),*) => ($(
pub use hyper::header::$name as $name;
)*)
}
import_http_headers! {
ACCEPT, ACCEPT_CHARSET, ACCEPT_ENCODING, ACCEPT_LANGUAGE, ACCEPT_RANGES,
ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS,
ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE,
ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, ALLOW,
AUTHORIZATION, CACHE_CONTROL, CONNECTION, CONTENT_DISPOSITION,
CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_LOCATION,
CONTENT_RANGE, CONTENT_SECURITY_POLICY,
CONTENT_SECURITY_POLICY_REPORT_ONLY, CONTENT_TYPE, DATE, ETAG, EXPECT,
EXPIRES, FORWARDED, FROM, HOST, IF_MATCH, IF_MODIFIED_SINCE,
IF_NONE_MATCH, IF_RANGE, IF_UNMODIFIED_SINCE, LAST_MODIFIED, LINK,
LOCATION, ORIGIN, PRAGMA, RANGE, REFERER, REFERRER_POLICY, REFRESH,
STRICT_TRANSPORT_SECURITY, TE, TRANSFER_ENCODING, UPGRADE, USER_AGENT,
VARY
}
}

View File

@ -4,15 +4,11 @@
//! Types that map to concepts in HTTP. //! Types that map to concepts in HTTP.
//! //!
//! This module exports types that map to HTTP concepts or to the underlying //! This module exports types that map to HTTP concepts or to the underlying
//! HTTP library when needed. Because the underlying HTTP library is likely to //! HTTP library when needed.
//! change (see [#17]), types in [`hyper`] should be considered unstable.
//!
//! [#17]: https://github.com/rwf2/Rocket/issues/17
#[macro_use] #[macro_use]
extern crate pear; extern crate pear;
pub mod hyper;
pub mod uri; pub mod uri;
pub mod ext; pub mod ext;
@ -22,7 +18,6 @@ mod method;
mod status; mod status;
mod raw_str; mod raw_str;
mod parse; mod parse;
mod listener;
/// Case-preserving, ASCII case-insensitive string types. /// Case-preserving, ASCII case-insensitive string types.
/// ///
@ -39,14 +34,8 @@ pub mod uncased {
pub mod private { pub mod private {
pub use crate::parse::Indexed; pub use crate::parse::Indexed;
pub use smallvec::{SmallVec, Array}; pub use smallvec::{SmallVec, Array};
pub use crate::listener::{TcpListener, Incoming, Listener, Connection, Certificates};
pub use cookie;
} }
#[doc(hidden)]
#[cfg(feature = "tls")]
pub mod tls;
pub use crate::method::Method; pub use crate::method::Method;
pub use crate::status::{Status, StatusClass}; pub use crate::status::{Status, StatusClass};
pub use crate::raw_str::{RawStr, RawStrBuf}; pub use crate::raw_str::{RawStr, RawStrBuf};

View File

@ -1,257 +0,0 @@
use std::fmt;
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::sync::Arc;
use log::warn;
use tokio::time::Sleep;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use hyper::server::accept::Accept;
use state::InitCell;
pub use tokio::net::TcpListener;
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
#[cfg(not(feature = "tls"))]
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CertificateDer(pub(crate) Vec<u8>);
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
#[cfg(feature = "tls")]
#[derive(Debug, Clone, Eq, PartialEq)]
#[repr(transparent)]
pub struct CertificateDer(pub(crate) rustls::pki_types::CertificateDer<'static>);
/// A collection of raw certificate data.
#[derive(Clone, Default)]
pub struct Certificates(Arc<InitCell<Vec<CertificateDer>>>);
impl From<Vec<CertificateDer>> for Certificates {
fn from(value: Vec<CertificateDer>) -> Self {
Certificates(Arc::new(value.into()))
}
}
#[cfg(feature = "tls")]
impl From<Vec<rustls::pki_types::CertificateDer<'static>>> for Certificates {
fn from(value: Vec<rustls::pki_types::CertificateDer<'static>>) -> Self {
let value: Vec<_> = value.into_iter().map(CertificateDer).collect();
Certificates(Arc::new(value.into()))
}
}
#[doc(hidden)]
impl Certificates {
/// Set the the raw certificate chain data. Only the first call actually
/// sets the data; the remaining do nothing.
#[cfg(feature = "tls")]
pub(crate) fn set(&self, data: Vec<CertificateDer>) {
self.0.set(data);
}
/// Returns the raw certificate chain data, if any is available.
pub fn chain_data(&self) -> Option<&[CertificateDer]> {
self.0.try_get().map(|v| v.as_slice())
}
}
// TODO.async: 'Listener' and 'Connection' provide common enough functionality
// that they could be introduced in upstream libraries.
/// A 'Listener' yields incoming connections
pub trait Listener {
/// The connection type returned by this listener.
type Connection: Connection;
/// Return the actual address this listener bound to.
fn local_addr(&self) -> Option<SocketAddr>;
/// Try to accept an incoming Connection if ready. This should only return
/// an `Err` when a fatal problem occurs as Hyper kills the server on `Err`.
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>>;
}
/// A 'Connection' represents an open connection to a client
pub trait Connection: AsyncRead + AsyncWrite {
/// The remote address, i.e. the client's socket address, if it is known.
fn peer_address(&self) -> Option<SocketAddr>;
/// Requests that the connection not delay reading or writing data as much
/// as possible. For connections backed by TCP, this corresponds to setting
/// `TCP_NODELAY`.
fn enable_nodelay(&self) -> io::Result<()>;
/// DER-encoded X.509 certificate chain presented by the client, if any.
///
/// The certificate order must be as it appears in the TLS protocol: the
/// first certificate relates to the peer, the second certifies the first,
/// the third certifies the second, and so on.
///
/// Defaults to an empty vector to indicate that no certificates were
/// presented.
fn peer_certificates(&self) -> Option<Certificates> { None }
}
pin_project_lite::pin_project! {
/// This is a generic version of hyper's AddrIncoming that is intended to be
/// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix
/// sockets. It does so by bridging the `Listener` trait to what hyper wants (an
/// Accept). This type is internal to Rocket.
#[must_use = "streams do nothing unless polled"]
pub struct Incoming<L> {
sleep_on_errors: Option<Duration>,
nodelay: bool,
#[pin]
pending_error_delay: Option<Sleep>,
#[pin]
listener: L,
}
}
impl<L: Listener> Incoming<L> {
/// Construct an `Incoming` from an existing `Listener`.
pub fn new(listener: L) -> Self {
Self {
listener,
sleep_on_errors: Some(Duration::from_millis(250)),
pending_error_delay: None,
nodelay: false,
}
}
/// Set whether and how long to sleep on accept errors.
///
/// A possible scenario is that the process has hit the max open files
/// allowed, and so trying to accept a new connection will fail with
/// `EMFILE`. In some cases, it's preferable to just wait for some time, if
/// the application will likely close some files (or connections), and try
/// to accept the connection again. If this option is `true`, the error
/// will be logged at the `error` level, since it is still a big deal,
/// and then the listener will sleep for 1 second.
///
/// In other cases, hitting the max open files should be treat similarly
/// to being out-of-memory, and simply error (and shutdown). Setting
/// this option to `None` will allow that.
///
/// Default is 1 second.
pub fn sleep_on_errors(mut self, val: Option<Duration>) -> Self {
self.sleep_on_errors = val;
self
}
/// Set whether to request no delay on all incoming connections. The default
/// is `false`. See [`Connection::enable_nodelay()`] for details.
pub fn nodelay(mut self, nodelay: bool) -> Self {
self.nodelay = nodelay;
self
}
fn poll_accept_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<L::Connection>> {
/// This function defines per-connection errors: errors that affect only
/// a single connection's accept() and don't imply anything about the
/// success probability of the next accept(). Thus, we can attempt to
/// `accept()` another connection immediately. All other errors will
/// incur a delay before the next `accept()` is performed. The delay is
/// useful to handle resource exhaustion errors like ENFILE and EMFILE.
/// Otherwise, could enter into tight loop.
fn is_connection_error(e: &io::Error) -> bool {
matches!(e.kind(),
| io::ErrorKind::ConnectionRefused
| io::ErrorKind::ConnectionAborted
| io::ErrorKind::ConnectionReset)
}
let mut this = self.project();
loop {
// Check if a previous sleep timer is active, set on I/O errors.
if let Some(delay) = this.pending_error_delay.as_mut().as_pin_mut() {
futures::ready!(delay.poll(cx));
}
this.pending_error_delay.set(None);
match futures::ready!(this.listener.as_mut().poll_accept(cx)) {
Ok(stream) => {
if *this.nodelay {
if let Err(e) = stream.enable_nodelay() {
warn!("failed to enable NODELAY: {}", e);
}
}
return Poll::Ready(Ok(stream));
},
Err(e) => {
if is_connection_error(&e) {
warn!("single connection accept error {}; accepting next now", e);
} else if let Some(duration) = this.sleep_on_errors {
// We might be able to recover. Try again in a bit.
warn!("accept error {}; recovery attempt in {}ms", e, duration.as_millis());
this.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
} else {
return Poll::Ready(Err(e));
}
},
}
}
}
}
impl<L: Listener> Accept for Incoming<L> {
type Conn = L::Connection;
type Error = io::Error;
#[inline]
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<Option<io::Result<Self::Conn>>> {
self.poll_accept_next(cx).map(Some)
}
}
impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Incoming")
.field("listener", &self.listener)
.finish()
}
}
impl Listener for TcpListener {
type Connection = TcpStream;
#[inline]
fn local_addr(&self) -> Option<SocketAddr> {
self.local_addr().ok()
}
#[inline]
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>> {
(*self).poll_accept(cx).map_ok(|(stream, _addr)| stream)
}
}
impl Connection for TcpStream {
#[inline]
fn peer_address(&self) -> Option<SocketAddr> {
self.peer_addr().ok()
}
#[inline]
fn enable_nodelay(&self) -> io::Result<()> {
self.set_nodelay(true)
}
}

View File

@ -3,8 +3,6 @@ use std::str::FromStr;
use self::Method::*; use self::Method::*;
use crate::hyper;
// TODO: Support non-standard methods, here and in codegen? // TODO: Support non-standard methods, here and in codegen?
/// Representation of HTTP methods. /// Representation of HTTP methods.
@ -29,6 +27,7 @@ use crate::hyper;
/// } /// }
/// # } /// # }
/// ``` /// ```
#[repr(u8)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Method { pub enum Method {
/// The `GET` variant. /// The `GET` variant.
@ -52,23 +51,6 @@ pub enum Method {
} }
impl Method { impl Method {
/// WARNING: This is unstable! Do not use this method outside of Rocket!
#[doc(hidden)]
pub fn from_hyp(method: &hyper::Method) -> Option<Method> {
match *method {
hyper::Method::GET => Some(Get),
hyper::Method::PUT => Some(Put),
hyper::Method::POST => Some(Post),
hyper::Method::DELETE => Some(Delete),
hyper::Method::OPTIONS => Some(Options),
hyper::Method::HEAD => Some(Head),
hyper::Method::TRACE => Some(Trace),
hyper::Method::CONNECT => Some(Connect),
hyper::Method::PATCH => Some(Patch),
_ => None,
}
}
/// Returns `true` if an HTTP request with the method represented by `self` /// Returns `true` if an HTTP request with the method represented by `self`
/// always supports a payload. /// always supports a payload.
/// ///

View File

@ -1,235 +0,0 @@
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::future::Future;
use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream};
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
use crate::listener::{Connection, Listener, Certificates, CertificateDer};
/// A TLS listener over TCP.
pub struct TlsListener {
listener: TcpListener,
acceptor: TlsAcceptor,
}
/// This implementation exists so that ROCKET_WORKERS=1 can make progress while
/// a TLS handshake is being completed. It does this by returning `Ready` from
/// `poll_accept()` as soon as we have a TCP connection and performing the
/// handshake in the `AsyncRead` and `AsyncWrite` implementations.
///
/// A straight-forward implementation of this strategy results in none of the
/// TLS information being available at the time the connection is "established",
/// that is, when `poll_accept()` returns, since the handshake has yet to occur.
/// Importantly, certificate information isn't available at the time that we
/// request it.
///
/// The underlying problem is hyper's "Accept" trait. Were we to manage
/// connections ourselves, we'd likely want to:
///
/// 1. Stop blocking the worker as soon as we have a TCP connection.
/// 2. Perform the handshake in the background.
/// 3. Give the connection to Rocket when/if the handshake is done.
///
/// See hyperium/hyper/issues/2321 for more details.
///
/// To work around this, we "lie" when `peer_certificates()` are requested and
/// always return `Some(Certificates)`. Internally, `Certificates` is an
/// `Arc<InitCell<Vec<CertificateDer>>>`, effectively a shared, thread-safe,
/// `OnceCell`. The cell is initially empty and is filled as soon as the
/// handshake is complete. If the certificate data were to be requested prior to
/// this point, it would be empty. However, in Rocket, we only request
/// certificate data when we have a `Request` object, which implies we're
/// receiving payload data, which implies the TLS handshake has finished, so the
/// certificate data as seen by a Rocket application will always be "fresh".
pub struct TlsStream {
remote: SocketAddr,
state: TlsState,
certs: Certificates,
}
/// State of `TlsStream`.
pub enum TlsState {
/// The TLS handshake is taking place. We don't have a full connection yet.
Handshaking(Accept<TcpStream>),
/// TLS handshake completed successfully; we're getting payload data.
Streaming(BareTlsStream<TcpStream>),
}
/// TLS as ~configured by `TlsConfig` in `rocket` core.
pub struct Config<R> {
pub cert_chain: R,
pub private_key: R,
pub ciphersuites: Vec<rustls::SupportedCipherSuite>,
pub prefer_server_order: bool,
pub ca_certs: Option<R>,
pub mandatory_mtls: bool,
}
impl TlsListener {
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> crate::tls::Result<TlsListener>
where R: io::BufRead
{
let provider = rustls::crypto::CryptoProvider {
cipher_suites: c.ciphersuites,
..rustls::crypto::ring::default_provider()
};
let verifier = match c.ca_certs {
Some(ref mut ca_certs) => {
let ca_roots = Arc::new(load_ca_certs(ca_certs)?);
let verifier = WebPkiClientVerifier::builder(ca_roots);
match c.mandatory_mtls {
true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?,
}
},
None => WebPkiClientVerifier::no_client_auth(),
};
let key = load_key(&mut c.private_key)?;
let cert_chain = load_cert_chain(&mut c.cert_chain)?;
let mut config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)?;
config.ignore_client_order = c.prefer_server_order;
config.session_storage = ServerSessionMemoryCache::new(1024);
config.ticketer = rustls::crypto::ring::Ticketer::new()?;
config.alpn_protocols = vec![b"http/1.1".to_vec()];
if cfg!(feature = "http2") {
config.alpn_protocols.insert(0, b"h2".to_vec());
}
let listener = TcpListener::bind(addr).await?;
let acceptor = TlsAcceptor::from(Arc::new(config));
Ok(TlsListener { listener, acceptor })
}
}
impl Listener for TlsListener {
type Connection = TlsStream;
fn local_addr(&self) -> Option<SocketAddr> {
self.listener.local_addr().ok()
}
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>> {
match futures::ready!(self.listener.poll_accept(cx)) {
Ok((io, addr)) => Poll::Ready(Ok(TlsStream {
remote: addr,
state: TlsState::Handshaking(self.acceptor.accept(io)),
// These are empty and filled in after handshake is complete.
certs: Certificates::default(),
})),
Err(e) => Poll::Ready(Err(e)),
}
}
}
impl Connection for TlsStream {
fn peer_address(&self) -> Option<SocketAddr> {
Some(self.remote)
}
fn enable_nodelay(&self) -> io::Result<()> {
// If `Handshaking` is `None`, it either failed, so we returned an `Err`
// from `poll_accept()` and there's no connection to enable `NODELAY`
// on, or it succeeded, so we're in the `Streaming` stage and we have
// infallible access to the connection.
match &self.state {
TlsState::Handshaking(accept) => match accept.get_ref() {
None => Ok(()),
Some(s) => s.enable_nodelay(),
},
TlsState::Streaming(stream) => stream.get_ref().0.enable_nodelay()
}
}
fn peer_certificates(&self) -> Option<Certificates> {
Some(self.certs.clone())
}
}
impl TlsStream {
fn poll_accept_then<F, T>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut f: F
) -> Poll<io::Result<T>>
where F: FnMut(&mut BareTlsStream<TcpStream>, &mut Context<'_>) -> Poll<io::Result<T>>
{
loop {
match self.state {
TlsState::Handshaking(ref mut accept) => {
match futures::ready!(Pin::new(accept).poll(cx)) {
Ok(stream) => {
if let Some(peer_certs) = stream.get_ref().1.peer_certificates() {
self.certs.set(peer_certs.into_iter()
.map(|v| CertificateDer(v.clone().into_owned()))
.collect());
}
self.state = TlsState::Streaming(stream);
}
Err(e) => {
log::warn!("tls handshake with {} failed: {}", self.remote, e);
return Poll::Ready(Err(e));
}
}
},
TlsState::Streaming(ref mut stream) => return f(stream, cx),
}
}
}
}
impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.poll_accept_then(cx, |stream, cx| Pin::new(stream).poll_read(cx, buf))
}
}
impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.poll_accept_then(cx, |stream, cx| Pin::new(stream).poll_write(cx, buf))
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut self.state {
TlsState::Handshaking(accept) => match accept.get_mut() {
Some(io) => Pin::new(io).poll_flush(cx),
None => Poll::Ready(Ok(())),
}
TlsState::Streaming(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut self.state {
TlsState::Handshaking(accept) => match accept.get_mut() {
Some(io) => Pin::new(io).poll_shutdown(cx),
None => Poll::Ready(Ok(())),
}
TlsState::Streaming(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}

View File

@ -1,11 +0,0 @@
mod listener;
#[cfg(feature = "mtls")]
pub mod mtls;
pub use rustls;
pub use listener::{TlsListener, Config};
pub mod util;
pub mod error;
pub use error::Result;

View File

@ -20,23 +20,36 @@ rust-version = "1.64"
all-features = true all-features = true
[features] [features]
default = ["http2"] default = ["http2", "tokio-macros"]
tls = ["rocket_http/tls"] http2 = ["hyper/http2", "hyper-util/http2"]
mtls = ["rocket_http/mtls", "tls"] secrets = ["cookie/private", "cookie/key-expansion"]
http2 = ["rocket_http/http2"] json = ["serde_json"]
secrets = ["rocket_http/private-cookies"] msgpack = ["rmp-serde"]
json = ["serde_json", "tokio/io-util"]
msgpack = ["rmp-serde", "tokio/io-util"]
uuid = ["uuid_", "rocket_http/uuid"] uuid = ["uuid_", "rocket_http/uuid"]
tls = ["rustls", "tokio-rustls", "rustls-pemfile"]
mtls = ["tls", "x509-parser"]
tokio-macros = ["tokio/macros"]
[dependencies] [dependencies]
# Serialization dependencies. # Optional serialization dependencies.
serde_json = { version = "1.0.26", optional = true } serde_json = { version = "1.0.26", optional = true }
rmp-serde = { version = "1", optional = true } rmp-serde = { version = "1", optional = true }
uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] } uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] }
# Optional TLS dependencies
rustls = { version = "0.22", optional = true }
tokio-rustls = { version = "0.25", optional = true }
rustls-pemfile = { version = "2.0.0", optional = true }
# Optional MTLS dependencies
x509-parser = { version = "0.13", optional = true }
# Hyper dependencies
http = "1"
bytes = "1.4"
hyper = { version = "1.1", default-features = false, features = ["http1", "server"] }
# Non-optional, core dependencies from here on out. # Non-optional, core dependencies from here on out.
futures = { version = "0.3.0", default-features = false, features = ["std"] }
yansi = { version = "1.0.0-rc", features = ["detect-tty"] } yansi = { version = "1.0.0-rc", features = ["detect-tty"] }
log = { version = "0.4", features = ["std"] } log = { version = "0.4", features = ["std"] }
num_cpus = "1.0" num_cpus = "1.0"
@ -44,11 +57,11 @@ time = { version = "0.3", features = ["macros", "parsing"] }
memchr = "2" # TODO: Use pear instead. memchr = "2" # TODO: Use pear instead.
binascii = "0.1" binascii = "0.1"
ref-cast = "1.0" ref-cast = "1.0"
atomic = "0.5" ref-swap = "0.1.2"
parking_lot = "0.12" parking_lot = "0.12"
ubyte = {version = "0.10.2", features = ["serde"] } ubyte = {version = "0.10.2", features = ["serde"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
figment = { version = "0.10.6", features = ["toml", "env"] } figment = { version = "0.10.13", features = ["toml", "env"] }
rand = "0.8" rand = "0.8"
either = "1" either = "1"
pin-project-lite = "0.2" pin-project-lite = "0.2"
@ -58,8 +71,25 @@ async-trait = "0.1.43"
async-stream = "0.3.2" async-stream = "0.3.2"
multer = { version = "3.0.0", features = ["tokio-io"] } multer = { version = "3.0.0", features = ["tokio-io"] }
tokio-stream = { version = "0.1.6", features = ["signal", "time"] } tokio-stream = { version = "0.1.6", features = ["signal", "time"] }
cookie = { version = "0.18", features = ["percent-encode"] }
futures = { version = "0.3.30", default-features = false, features = ["std"] }
state = "0.6" state = "0.6"
[dependencies.hyper-util]
git = "https://github.com/SergioBenitez/hyper-util.git"
branch = "fix-readversion"
default-features = false
features = ["http1", "server", "tokio"]
[dependencies.tokio]
version = "1.35.1"
features = ["rt-multi-thread", "net", "io-util", "fs", "time", "sync", "signal", "parking_lot"]
[dependencies.tokio-util]
version = "0.7"
default-features = false
features = ["io"]
[dependencies.rocket_codegen] [dependencies.rocket_codegen]
version = "0.6.0-dev" version = "0.6.0-dev"
path = "../codegen" path = "../codegen"
@ -69,21 +99,13 @@ version = "0.6.0-dev"
path = "../http" path = "../http"
features = ["serde"] features = ["serde"]
[dependencies.tokio] [target.'cfg(unix)'.dependencies]
version = "1.6.1" libc = "0.2.149"
features = ["fs", "io-std", "io-util", "rt-multi-thread", "sync", "signal", "macros"]
[dependencies.tokio-util]
version = "0.7"
default-features = false
features = ["io"]
[dependencies.bytes]
version = "1.0"
[build-dependencies] [build-dependencies]
version_check = "0.9.1" version_check = "0.9.1"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1", features = ["macros", "io-std"] }
figment = { version = "0.10", features = ["test"] } figment = { version = "0.10", features = ["test"] }
pretty_assertions = "1" pretty_assertions = "1"

View File

@ -1,5 +1,3 @@
use std::net::{IpAddr, Ipv4Addr};
use figment::{Figment, Profile, Provider, Metadata, error::Result}; use figment::{Figment, Profile, Provider, Metadata, error::Result};
use figment::providers::{Serialized, Env, Toml, Format}; use figment::providers::{Serialized, Env, Toml, Format};
use figment::value::{Map, Dict, magic::RelativePathBuf}; use figment::value::{Map, Dict, magic::RelativePathBuf};
@ -12,9 +10,6 @@ use crate::request::{self, Request, FromRequest};
use crate::http::uncased::Uncased; use crate::http::uncased::Uncased;
use crate::data::Limits; use crate::data::Limits;
#[cfg(feature = "tls")]
use crate::config::TlsConfig;
#[cfg(feature = "secrets")] #[cfg(feature = "secrets")]
use crate::config::SecretKey; use crate::config::SecretKey;
@ -66,10 +61,6 @@ pub struct Config {
/// set to the extracting Figment's selected `Profile`._ /// set to the extracting Figment's selected `Profile`._
#[serde(skip)] #[serde(skip)]
pub profile: Profile, pub profile: Profile,
/// IP address to serve on. **(default: `127.0.0.1`)**
pub address: IpAddr,
/// Port to serve on. **(default: `8000`)**
pub port: u16,
/// Number of threads to use for executing futures. **(default: `num_cores`)** /// Number of threads to use for executing futures. **(default: `num_cores`)**
/// ///
/// _**Note:** Rocket only reads this value from sources in the [default /// _**Note:** Rocket only reads this value from sources in the [default
@ -121,10 +112,6 @@ pub struct Config {
pub temp_dir: RelativePathBuf, pub temp_dir: RelativePathBuf,
/// Keep-alive timeout in seconds; disabled when `0`. **(default: `5`)** /// Keep-alive timeout in seconds; disabled when `0`. **(default: `5`)**
pub keep_alive: u32, pub keep_alive: u32,
/// The TLS configuration, if any. **(default: `None`)**
#[cfg(feature = "tls")]
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
pub tls: Option<TlsConfig>,
/// The secret key for signing and encrypting. **(default: `0`)** /// The secret key for signing and encrypting. **(default: `0`)**
/// ///
/// _**Note:** This field _always_ serializes as a 256-bit array of `0`s to /// _**Note:** This field _always_ serializes as a 256-bit array of `0`s to
@ -148,7 +135,6 @@ pub struct Config {
/// use rocket::Config; /// use rocket::Config;
/// ///
/// let config = Config { /// let config = Config {
/// port: 1024,
/// keep_alive: 10, /// keep_alive: 10,
/// ..Default::default() /// ..Default::default()
/// }; /// };
@ -204,8 +190,6 @@ impl Config {
pub fn debug_default() -> Config { pub fn debug_default() -> Config {
Config { Config {
profile: Self::DEBUG_PROFILE, profile: Self::DEBUG_PROFILE,
address: Ipv4Addr::new(127, 0, 0, 1).into(),
port: 8000,
workers: num_cpus::get(), workers: num_cpus::get(),
max_blocking: 512, max_blocking: 512,
ident: Ident::default(), ident: Ident::default(),
@ -214,8 +198,6 @@ impl Config {
limits: Limits::default(), limits: Limits::default(),
temp_dir: std::env::temp_dir().into(), temp_dir: std::env::temp_dir().into(),
keep_alive: 5, keep_alive: 5,
#[cfg(feature = "tls")]
tls: None,
#[cfg(feature = "secrets")] #[cfg(feature = "secrets")]
secret_key: SecretKey::zero(), secret_key: SecretKey::zero(),
shutdown: Shutdown::default(), shutdown: Shutdown::default(),
@ -331,59 +313,6 @@ impl Config {
Self::try_from(provider).unwrap_or_else(bail_with_config_error) Self::try_from(provider).unwrap_or_else(bail_with_config_error)
} }
/// Returns `true` if TLS is enabled.
///
/// TLS is enabled when the `tls` feature is enabled and TLS has been
/// configured with at least one ciphersuite. Note that without changing
/// defaults, all supported ciphersuites are enabled in the recommended
/// configuration.
///
/// # Example
///
/// ```rust
/// let config = rocket::Config::default();
/// if config.tls_enabled() {
/// println!("TLS is enabled!");
/// } else {
/// println!("TLS is disabled.");
/// }
/// ```
pub fn tls_enabled(&self) -> bool {
#[cfg(feature = "tls")] {
self.tls.as_ref().map_or(false, |tls| !tls.ciphers.is_empty())
}
#[cfg(not(feature = "tls"))] { false }
}
/// Returns `true` if mTLS is enabled.
///
/// mTLS is enabled when TLS is enabled ([`Config::tls_enabled()`]) _and_
/// the `mtls` feature is enabled _and_ mTLS has been configured with a CA
/// certificate chain.
///
/// # Example
///
/// ```rust
/// let config = rocket::Config::default();
/// if config.mtls_enabled() {
/// println!("mTLS is enabled!");
/// } else {
/// println!("mTLS is disabled.");
/// }
/// ```
pub fn mtls_enabled(&self) -> bool {
if !self.tls_enabled() {
return false;
}
#[cfg(feature = "mtls")] {
self.tls.as_ref().map_or(false, |tls| tls.mutual.is_some())
}
#[cfg(not(feature = "mtls"))] { false }
}
#[cfg(feature = "secrets")] #[cfg(feature = "secrets")]
pub(crate) fn known_secret_key_used(&self) -> bool { pub(crate) fn known_secret_key_used(&self) -> bool {
const KNOWN_SECRET_KEYS: &'static [&'static str] = &[ const KNOWN_SECRET_KEYS: &'static [&'static str] = &[
@ -420,8 +349,6 @@ impl Config {
self.trace_print(figment); self.trace_print(figment);
launch_meta!("{}Configured for {}.", "🔧 ".emoji(), self.profile.underline()); launch_meta!("{}Configured for {}.", "🔧 ".emoji(), self.profile.underline());
launch_meta_!("address: {}", self.address.paint(VAL));
launch_meta_!("port: {}", self.port.paint(VAL));
launch_meta_!("workers: {}", self.workers.paint(VAL)); launch_meta_!("workers: {}", self.workers.paint(VAL));
launch_meta_!("max blocking threads: {}", self.max_blocking.paint(VAL)); launch_meta_!("max blocking threads: {}", self.max_blocking.paint(VAL));
launch_meta_!("ident: {}", self.ident.paint(VAL)); launch_meta_!("ident: {}", self.ident.paint(VAL));
@ -445,12 +372,6 @@ impl Config {
ka => launch_meta_!("keep-alive: {}{}", ka.paint(VAL), "s".paint(VAL)), ka => launch_meta_!("keep-alive: {}{}", ka.paint(VAL), "s".paint(VAL)),
} }
match (self.tls_enabled(), self.mtls_enabled()) {
(true, true) => launch_meta_!("tls: {}", "enabled w/mtls".paint(VAL)),
(true, false) => launch_meta_!("tls: {} w/o mtls", "enabled".paint(VAL)),
(false, _) => launch_meta_!("tls: {}", "disabled".paint(VAL)),
}
launch_meta_!("shutdown: {}", self.shutdown.paint(VAL)); launch_meta_!("shutdown: {}", self.shutdown.paint(VAL));
launch_meta_!("log level: {}", self.log_level.paint(VAL)); launch_meta_!("log level: {}", self.log_level.paint(VAL));
launch_meta_!("cli colors: {}", self.cli_colors.paint(VAL)); launch_meta_!("cli colors: {}", self.cli_colors.paint(VAL));
@ -519,12 +440,6 @@ impl Config {
/// This isn't `pub` because setting it directly does nothing. /// This isn't `pub` because setting it directly does nothing.
const PROFILE: &'static str = "profile"; const PROFILE: &'static str = "profile";
/// The stringy parameter name for setting/extracting [`Config::address`].
pub const ADDRESS: &'static str = "address";
/// The stringy parameter name for setting/extracting [`Config::port`].
pub const PORT: &'static str = "port";
/// The stringy parameter name for setting/extracting [`Config::workers`]. /// The stringy parameter name for setting/extracting [`Config::workers`].
pub const WORKERS: &'static str = "workers"; pub const WORKERS: &'static str = "workers";
@ -546,9 +461,6 @@ impl Config {
/// The stringy parameter name for setting/extracting [`Config::limits`]. /// The stringy parameter name for setting/extracting [`Config::limits`].
pub const LIMITS: &'static str = "limits"; pub const LIMITS: &'static str = "limits";
/// The stringy parameter name for setting/extracting [`Config::tls`].
pub const TLS: &'static str = "tls";
/// The stringy parameter name for setting/extracting [`Config::secret_key`]. /// The stringy parameter name for setting/extracting [`Config::secret_key`].
pub const SECRET_KEY: &'static str = "secret_key"; pub const SECRET_KEY: &'static str = "secret_key";
@ -566,9 +478,10 @@ impl Config {
/// An array of all of the stringy parameter names. /// An array of all of the stringy parameter names.
pub const PARAMETERS: &'static [&'static str] = &[ pub const PARAMETERS: &'static [&'static str] = &[
Self::ADDRESS, Self::PORT, Self::WORKERS, Self::MAX_BLOCKING, Self::KEEP_ALIVE, Self::WORKERS, Self::MAX_BLOCKING, Self::KEEP_ALIVE, Self::IDENT,
Self::IDENT, Self::IP_HEADER, Self::PROXY_PROTO_HEADER, Self::LIMITS, Self::TLS, Self::IP_HEADER, Self::PROXY_PROTO_HEADER, Self::LIMITS,
Self::SECRET_KEY, Self::TEMP_DIR, Self::LOG_LEVEL, Self::SHUTDOWN, Self::CLI_COLORS, Self::SECRET_KEY, Self::TEMP_DIR, Self::LOG_LEVEL, Self::SHUTDOWN,
Self::CLI_COLORS,
]; ];
} }

View File

@ -117,9 +117,6 @@ mod shutdown;
mod cli_colors; mod cli_colors;
mod http_header; mod http_header;
#[cfg(feature = "tls")]
mod tls;
#[cfg(feature = "secrets")] #[cfg(feature = "secrets")]
mod secret_key; mod secret_key;
@ -132,12 +129,6 @@ pub use shutdown::Shutdown;
pub use ident::Ident; pub use ident::Ident;
pub use cli_colors::CliColors; pub use cli_colors::CliColors;
#[cfg(feature = "tls")]
pub use tls::{TlsConfig, CipherSuite};
#[cfg(feature = "mtls")]
pub use tls::MutualTls;
#[cfg(feature = "secrets")] #[cfg(feature = "secrets")]
pub use secret_key::SecretKey; pub use secret_key::SecretKey;
@ -146,7 +137,6 @@ pub use shutdown::Sig;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::net::Ipv4Addr;
use figment::{Figment, Profile}; use figment::{Figment, Profile};
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
@ -202,9 +192,7 @@ mod tests {
figment::Jail::expect_with(|jail| { figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#" jail.create_file("Rocket.toml", r#"
[default] [default]
address = "1.2.3.4"
ident = "Something Cool" ident = "Something Cool"
port = 1234
workers = 20 workers = 20
keep_alive = 10 keep_alive = 10
log_level = "off" log_level = "off"
@ -213,8 +201,6 @@ mod tests {
let config = Config::from(Config::figment()); let config = Config::from(Config::figment());
assert_eq!(config, Config { assert_eq!(config, Config {
address: Ipv4Addr::new(1, 2, 3, 4).into(),
port: 1234,
workers: 20, workers: 20,
ident: ident!("Something Cool"), ident: ident!("Something Cool"),
keep_alive: 10, keep_alive: 10,
@ -225,9 +211,7 @@ mod tests {
jail.create_file("Rocket.toml", r#" jail.create_file("Rocket.toml", r#"
[global] [global]
address = "1.2.3.4"
ident = "Something Else Cool" ident = "Something Else Cool"
port = 1234
workers = 20 workers = 20
keep_alive = 10 keep_alive = 10
log_level = "off" log_level = "off"
@ -236,8 +220,6 @@ mod tests {
let config = Config::from(Config::figment()); let config = Config::from(Config::figment());
assert_eq!(config, Config { assert_eq!(config, Config {
address: Ipv4Addr::new(1, 2, 3, 4).into(),
port: 1234,
workers: 20, workers: 20,
ident: ident!("Something Else Cool"), ident: ident!("Something Else Cool"),
keep_alive: 10, keep_alive: 10,
@ -249,8 +231,6 @@ mod tests {
jail.set_env("ROCKET_CONFIG", "Other.toml"); jail.set_env("ROCKET_CONFIG", "Other.toml");
jail.create_file("Other.toml", r#" jail.create_file("Other.toml", r#"
[default] [default]
address = "1.2.3.4"
port = 1234
workers = 20 workers = 20
keep_alive = 10 keep_alive = 10
log_level = "off" log_level = "off"
@ -259,8 +239,6 @@ mod tests {
let config = Config::from(Config::figment()); let config = Config::from(Config::figment());
assert_eq!(config, Config { assert_eq!(config, Config {
address: Ipv4Addr::new(1, 2, 3, 4).into(),
port: 1234,
workers: 20, workers: 20,
keep_alive: 10, keep_alive: 10,
log_level: LogLevel::Off, log_level: LogLevel::Off,
@ -367,228 +345,6 @@ mod tests {
}) })
} }
#[test]
#[cfg(feature = "tls")]
fn test_tls_config_from_file() {
use crate::config::{TlsConfig, CipherSuite, Ident, Shutdown};
figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#"
[global]
shutdown.ctrlc = 0
ident = false
[global.tls]
certs = "/ssl/cert.pem"
key = "/ssl/key.pem"
[global.limits]
forms = "1mib"
json = "10mib"
stream = "50kib"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
shutdown: Shutdown { ctrlc: false, ..Default::default() },
ident: Ident::none(),
tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")),
limits: Limits::default()
.limit("forms", 1.mebibytes())
.limit("json", 10.mebibytes())
.limit("stream", 50.kibibytes()),
..Config::default()
});
jail.create_file("Rocket.toml", r#"
[global.tls]
certs = "cert.pem"
key = "key.pem"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
tls: Some(TlsConfig::from_paths(
jail.directory().join("cert.pem"),
jail.directory().join("key.pem")
)),
..Config::default()
});
jail.create_file("Rocket.toml", r#"
[global.tls]
certs = "cert.pem"
key = "key.pem"
prefer_server_cipher_order = true
ciphers = [
"TLS_CHACHA20_POLY1305_SHA256",
"TLS_AES_256_GCM_SHA384",
"TLS_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
]
"#)?;
let config = Config::from(Config::figment());
let cert_path = jail.directory().join("cert.pem");
let key_path = jail.directory().join("key.pem");
assert_eq!(config, Config {
tls: Some(TlsConfig::from_paths(cert_path, key_path)
.with_preferred_server_cipher_order(true)
.with_ciphers([
CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_AES_256_GCM_SHA384,
CipherSuite::TLS_AES_128_GCM_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
])),
..Config::default()
});
jail.create_file("Rocket.toml", r#"
[global]
shutdown.ctrlc = 0
ident = false
[global.tls]
certs = "/ssl/cert.pem"
key = "/ssl/key.pem"
[global.limits]
forms = "1mib"
json = "10mib"
stream = "50kib"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
shutdown: Shutdown { ctrlc: false, ..Default::default() },
ident: Ident::none(),
tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")),
limits: Limits::default()
.limit("forms", 1.mebibytes())
.limit("json", 10.mebibytes())
.limit("stream", 50.kibibytes()),
..Config::default()
});
jail.create_file("Rocket.toml", r#"
[global.tls]
certs = "cert.pem"
key = "key.pem"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
tls: Some(TlsConfig::from_paths(
jail.directory().join("cert.pem"),
jail.directory().join("key.pem")
)),
..Config::default()
});
jail.create_file("Rocket.toml", r#"
[global.tls]
certs = "cert.pem"
key = "key.pem"
prefer_server_cipher_order = true
ciphers = [
"TLS_CHACHA20_POLY1305_SHA256",
"TLS_AES_256_GCM_SHA384",
"TLS_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
]
"#)?;
let config = Config::from(Config::figment());
let cert_path = jail.directory().join("cert.pem");
let key_path = jail.directory().join("key.pem");
assert_eq!(config, Config {
tls: Some(TlsConfig::from_paths(cert_path, key_path)
.with_preferred_server_cipher_order(true)
.with_ciphers([
CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_AES_256_GCM_SHA384,
CipherSuite::TLS_AES_128_GCM_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
])),
..Config::default()
});
Ok(())
});
}
#[test]
#[cfg(feature = "mtls")]
fn test_mtls_config() {
use std::path::Path;
figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#"
[default.tls]
certs = "/ssl/cert.pem"
key = "/ssl/key.pem"
"#)?;
let config = Config::from(Config::figment());
assert!(config.tls.is_some());
assert!(config.tls.as_ref().unwrap().mutual.is_none());
assert!(config.tls_enabled());
assert!(!config.mtls_enabled());
jail.create_file("Rocket.toml", r#"
[default.tls]
certs = "/ssl/cert.pem"
key = "/ssl/key.pem"
mutual = { ca_certs = "/ssl/ca.pem" }
"#)?;
let config = Config::from(Config::figment());
assert!(config.tls_enabled());
assert!(config.mtls_enabled());
let mtls = config.tls.as_ref().unwrap().mutual.as_ref().unwrap();
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
assert!(!mtls.mandatory);
jail.create_file("Rocket.toml", r#"
[default.tls]
certs = "/ssl/cert.pem"
key = "/ssl/key.pem"
[default.tls.mutual]
ca_certs = "/ssl/ca.pem"
mandatory = true
"#)?;
let config = Config::from(Config::figment());
let mtls = config.tls.as_ref().unwrap().mutual.as_ref().unwrap();
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
assert!(mtls.mandatory);
jail.create_file("Rocket.toml", r#"
[default.tls]
certs = "/ssl/cert.pem"
key = "/ssl/key.pem"
mutual = { ca_certs = "relative/ca.pem" }
"#)?;
let config = Config::from(Config::figment());
let mtls = config.tls.as_ref().unwrap().mutual().unwrap();
assert_eq!(mtls.ca_certs().unwrap_left(),
jail.directory().join("relative/ca.pem"));
Ok(())
});
}
#[test] #[test]
fn test_profiles_merge() { fn test_profiles_merge() {
figment::Jail::expect_with(|jail| { figment::Jail::expect_with(|jail| {
@ -629,42 +385,41 @@ mod tests {
} }
#[test] #[test]
#[cfg(feature = "tls")]
fn test_env_vars_merge() { fn test_env_vars_merge() {
use crate::config::{TlsConfig, Ident}; use crate::config::{Ident, Shutdown};
figment::Jail::expect_with(|jail| { figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PORT", 9999); jail.set_env("ROCKET_KEEP_ALIVE", 9999);
let config = Config::from(Config::figment()); let config = Config::from(Config::figment());
assert_eq!(config, Config { assert_eq!(config, Config {
port: 9999, keep_alive: 9999,
..Config::default() ..Config::default()
}); });
jail.set_env("ROCKET_TLS", r#"{certs="certs.pem"}"#); jail.set_env("ROCKET_SHUTDOWN", r#"{grace=7}"#);
let first_figment = Config::figment(); let first_figment = Config::figment();
jail.set_env("ROCKET_TLS", r#"{key="key.pem"}"#); jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=10}"#);
let prev_figment = Config::figment().join(&first_figment); let prev_figment = Config::figment().join(&first_figment);
let config = Config::from(&prev_figment); let config = Config::from(&prev_figment);
assert_eq!(config, Config { assert_eq!(config, Config {
port: 9999, keep_alive: 9999,
tls: Some(TlsConfig::from_paths("certs.pem", "key.pem")), shutdown: Shutdown { grace: 7, mercy: 10, ..Default::default() },
..Config::default() ..Config::default()
}); });
jail.set_env("ROCKET_TLS", r#"{certs="new.pem"}"#); jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=20}"#);
let config = Config::from(Config::figment().join(&prev_figment)); let config = Config::from(Config::figment().join(&prev_figment));
assert_eq!(config, Config { assert_eq!(config, Config {
port: 9999, keep_alive: 9999,
tls: Some(TlsConfig::from_paths("new.pem", "key.pem")), shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
..Config::default() ..Config::default()
}); });
jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#); jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#);
let config = Config::from(Config::figment().join(&prev_figment)); let config = Config::from(Config::figment().join(&prev_figment));
assert_eq!(config, Config { assert_eq!(config, Config {
port: 9999, keep_alive: 9999,
tls: Some(TlsConfig::from_paths("new.pem", "key.pem")), shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
limits: Limits::default().limit("stream", 100.kibibytes()), limits: Limits::default().limit("stream", 100.kibibytes()),
..Config::default() ..Config::default()
}); });
@ -672,8 +427,8 @@ mod tests {
jail.set_env("ROCKET_IDENT", false); jail.set_env("ROCKET_IDENT", false);
let config = Config::from(Config::figment().join(&prev_figment)); let config = Config::from(Config::figment().join(&prev_figment));
assert_eq!(config, Config { assert_eq!(config, Config {
port: 9999, keep_alive: 9999,
tls: Some(TlsConfig::from_paths("new.pem", "key.pem")), shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
limits: Limits::default().limit("stream", 100.kibibytes()), limits: Limits::default().limit("stream", 100.kibibytes()),
ident: Ident::none(), ident: Ident::none(),
..Config::default() ..Config::default()

View File

@ -1,8 +1,8 @@
use std::fmt; use std::fmt;
use cookie::Key;
use serde::{de, ser, Deserialize, Serialize}; use serde::{de, ser, Deserialize, Serialize};
use crate::http::private::cookie::Key;
use crate::request::{Outcome, Request, FromRequest}; use crate::request::{Outcome, Request, FromRequest};
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]

View File

@ -1,4 +1,4 @@
use std::fmt; use std::{fmt, time::Duration};
#[cfg(unix)] #[cfg(unix)]
use std::collections::HashSet; use std::collections::HashSet;
@ -291,6 +291,14 @@ impl Default for Shutdown {
} }
impl Shutdown { impl Shutdown {
pub(crate) fn grace(&self) -> Duration {
Duration::from_secs(self.grace as u64)
}
pub(crate) fn mercy(&self) -> Duration {
Duration::from_secs(self.mercy as u64)
}
#[cfg(unix)] #[cfg(unix)]
pub(crate) fn signal_stream(&self) -> Option<impl Stream<Item = Sig>> { pub(crate) fn signal_stream(&self) -> Option<impl Stream<Item = Sig>> {
use tokio_stream::{StreamExt, StreamMap, wrappers::SignalStream}; use tokio_stream::{StreamExt, StreamMap, wrappers::SignalStream};

View File

@ -3,16 +3,16 @@ use std::task::{Context, Poll};
use std::path::Path; use std::path::Path;
use std::io::{self, Cursor}; use std::io::{self, Cursor};
use futures::ready;
use futures::stream::Stream;
use tokio::fs::File; use tokio::fs::File;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
use tokio_util::io::StreamReader; use tokio_util::io::StreamReader;
use futures::{ready, stream::Stream}; use hyper::body::{Body, Bytes, Incoming as HyperBody};
use crate::http::hyper;
use crate::ext::{PollExt, Chain};
use crate::data::{Capped, N}; use crate::data::{Capped, N};
use crate::http::hyper::body::Bytes;
use crate::data::transform::Transform; use crate::data::transform::Transform;
use crate::util::Chain;
use super::peekable::Peekable; use super::peekable::Peekable;
use super::transform::TransformBuf; use super::transform::TransformBuf;
@ -68,7 +68,7 @@ pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
/// Raw underlying data stream. /// Raw underlying data stream.
pub enum RawStream<'r> { pub enum RawStream<'r> {
Empty, Empty,
Body(&'r mut hyper::Body), Body(&'r mut HyperBody),
Multipart(multer::Field<'r>), Multipart(multer::Field<'r>),
} }
@ -154,8 +154,14 @@ impl<'r> DataStream<'r> {
/// ``` /// ```
pub fn hint(&self) -> usize { pub fn hint(&self) -> usize {
let base = self.base(); let base = self.base();
let buf_len = base.get_ref().get_ref().0.get_ref().len(); if let (Some(cursor), _) = base.get_ref().get_ref() {
std::cmp::min(buf_len, base.limit() as usize) let len = cursor.get_ref().len() as u64;
let position = cursor.position().min(len);
let remaining = len - position;
remaining.min(base.limit()) as usize
} else {
0
}
} }
/// A helper method to write the body of the request to any `AsyncWrite` /// A helper method to write the body of the request to any `AsyncWrite`
@ -331,17 +337,25 @@ impl Stream for RawStream<'_> {
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.get_mut() { match self.get_mut() {
RawStream::Body(body) => Pin::new(body).poll_next(cx) // TODO: Expose trailer headers, somehow.
.map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)), RawStream::Body(body) => {
RawStream::Multipart(mp) => Pin::new(mp).poll_next(cx) Pin::new(body)
.map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)), .poll_frame(cx)
.map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new()))
.map_err(io::Error::other)
}
RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other),
RawStream::Empty => Poll::Ready(None), RawStream::Empty => Poll::Ready(None),
} }
} }
fn size_hint(&self) -> (usize, Option<usize>) { fn size_hint(&self) -> (usize, Option<usize>) {
match self { match self {
RawStream::Body(body) => body.size_hint(), RawStream::Body(body) => {
let hint = body.size_hint();
let (lower, upper) = (hint.lower(), hint.upper());
(lower as usize, upper.map(|x| x as usize))
},
RawStream::Multipart(mp) => mp.size_hint(), RawStream::Multipart(mp) => mp.size_hint(),
RawStream::Empty => (0, Some(0)), RawStream::Empty => (0, Some(0)),
} }
@ -358,8 +372,8 @@ impl std::fmt::Display for RawStream<'_> {
} }
} }
impl<'r> From<&'r mut hyper::Body> for RawStream<'r> { impl<'r> From<&'r mut HyperBody> for RawStream<'r> {
fn from(value: &'r mut hyper::Body) -> Self { fn from(value: &'r mut HyperBody) -> Self {
Self::Body(value) Self::Body(value)
} }
} }

View File

@ -3,8 +3,8 @@ use std::task::{Context, Poll};
use std::pin::Pin; use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use hyper::upgrade::Upgraded;
use crate::http::hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo;
/// A bidirectional, raw stream to the client. /// A bidirectional, raw stream to the client.
/// ///
@ -28,7 +28,7 @@ pub struct IoStream {
/// Just in case we want to add stream kinds in the future. /// Just in case we want to add stream kinds in the future.
enum IoStreamKind { enum IoStreamKind {
Upgraded(Upgraded) Upgraded(TokioIo<Upgraded>)
} }
/// An upgraded connection I/O handler. /// An upgraded connection I/O handler.
@ -51,7 +51,7 @@ enum IoStreamKind {
/// ///
/// #[rocket::async_trait] /// #[rocket::async_trait]
/// impl IoHandler for EchoHandler { /// impl IoHandler for EchoHandler {
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> { /// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io); /// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?; /// io::copy(&mut reader, &mut writer).await?;
/// Ok(()) /// Ok(())
@ -68,13 +68,20 @@ enum IoStreamKind {
#[crate::async_trait] #[crate::async_trait]
pub trait IoHandler: Send { pub trait IoHandler: Send {
/// Performs the raw I/O. /// Performs the raw I/O.
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()>; async fn io(self: Box<Self>, io: IoStream) -> io::Result<()>;
}
#[crate::async_trait]
impl IoHandler for () {
async fn io(self: Box<Self>, _: IoStream) -> io::Result<()> {
Ok(())
}
} }
#[doc(hidden)] #[doc(hidden)]
impl From<Upgraded> for IoStream { impl From<Upgraded> for IoStream {
fn from(io: Upgraded) -> Self { fn from(io: Upgraded) -> Self {
IoStream { kind: IoStreamKind::Upgraded(io) } IoStream { kind: IoStreamKind::Upgraded(TokioIo::new(io)) }
} }
} }

View File

@ -178,7 +178,7 @@ impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
#[allow(deprecated)] #[allow(deprecated)]
mod tests { mod tests {
use std::hash::SipHasher; use std::hash::SipHasher;
use std::sync::{Arc, atomic::{AtomicU64, AtomicU8}}; use std::sync::{Arc, atomic::{AtomicU8, AtomicU64, Ordering}};
use parking_lot::Mutex; use parking_lot::Mutex;
use ubyte::ToByteUnit; use ubyte::ToByteUnit;
@ -264,41 +264,41 @@ mod tests {
assert_eq!(bytes.len(), 8); assert_eq!(bytes.len(), 8);
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]"); let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
let value = u64::from_be_bytes(bytes); let value = u64::from_be_bytes(bytes);
hash1.store(value, atomic::Ordering::Release); hash1.store(value, Ordering::Release);
}) })
.chain_inspect(move |bytes| { .chain_inspect(move |bytes| {
assert_eq!(bytes.len(), 8); assert_eq!(bytes.len(), 8);
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]"); let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
let value = u64::from_be_bytes(bytes); let value = u64::from_be_bytes(bytes);
let prev = hash2.load(atomic::Ordering::Acquire); let prev = hash2.load(Ordering::Acquire);
assert_eq!(prev, value); assert_eq!(prev, value);
inspect2.fetch_add(1, atomic::Ordering::Release); inspect2.fetch_add(1, Ordering::Release);
}); });
}))); })));
// Make sure nothing has happened yet. // Make sure nothing has happened yet.
assert!(raw_data.lock().is_empty()); assert!(raw_data.lock().is_empty());
assert_eq!(hash.load(atomic::Ordering::Acquire), 0); assert_eq!(hash.load(Ordering::Acquire), 0);
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0); assert_eq!(inspect2.load(Ordering::Acquire), 0);
// Check that nothing happens if the data isn't read. // Check that nothing happens if the data isn't read.
let client = Client::debug(rocket).unwrap(); let client = Client::debug(rocket).unwrap();
client.get("/").body("Hello, world!").dispatch(); client.get("/").body("Hello, world!").dispatch();
assert!(raw_data.lock().is_empty()); assert!(raw_data.lock().is_empty());
assert_eq!(hash.load(atomic::Ordering::Acquire), 0); assert_eq!(hash.load(Ordering::Acquire), 0);
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0); assert_eq!(inspect2.load(Ordering::Acquire), 0);
// Check inspect + hash + inspect + inspect. // Check inspect + hash + inspect + inspect.
client.post("/").body("Hello, world!").dispatch(); client.post("/").body("Hello, world!").dispatch();
assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes()); assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes());
assert_eq!(hash.load(atomic::Ordering::Acquire), 0xae5020d7cf49d14f); assert_eq!(hash.load(Ordering::Acquire), 0xae5020d7cf49d14f);
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 1); assert_eq!(inspect2.load(Ordering::Acquire), 1);
// Check inspect + hash + inspect + inspect, round 2. // Check inspect + hash + inspect + inspect, round 2.
let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!"; let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
client.post("/").body(string).dispatch(); client.post("/").body(string).dispatch();
assert_eq!(raw_data.lock().as_slice(), string.as_bytes()); assert_eq!(raw_data.lock().as_slice(), string.as_bytes());
assert_eq!(hash.load(atomic::Ordering::Acquire), 0x323f9aa98f907faf); assert_eq!(hash.load(Ordering::Acquire), 0x323f9aa98f907faf);
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 2); assert_eq!(inspect2.load(Ordering::Acquire), 2);
} }
} }

193
core/lib/src/erased.rs Normal file
View File

@ -0,0 +1,193 @@
use std::io;
use std::mem::transmute;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Poll, Context};
use futures::future::BoxFuture;
use http::request::Parts;
use hyper::body::Incoming;
use tokio::io::{AsyncRead, ReadBuf};
use crate::data::{Data, IoHandler};
use crate::{Request, Response, Rocket, Orbit};
// TODO: Magic with trait async fn to get rid of the box pin.
// TODO: Write safety proofs.
macro_rules! static_assert_covariance {
($T:tt) => (
const _: () = {
fn _assert_covariance<'x: 'y, 'y>(x: &'y $T<'x>) -> &'y $T<'y> { x }
};
)
}
#[derive(Debug)]
pub struct ErasedRequest {
// XXX: SAFETY: This (dependent) field must come first due to drop order!
request: Request<'static>,
_rocket: Arc<Rocket<Orbit>>,
_parts: Box<Parts>,
}
impl Drop for ErasedRequest {
fn drop(&mut self) { }
}
#[derive(Debug)]
pub struct ErasedResponse {
// XXX: SAFETY: This (dependent) field must come first due to drop order!
response: Response<'static>,
_request: Arc<ErasedRequest>,
_incoming: Box<Incoming>,
}
impl Drop for ErasedResponse {
fn drop(&mut self) { }
}
pub struct ErasedIoHandler {
// XXX: SAFETY: This (dependent) field must come first due to drop order!
io: Box<dyn IoHandler + 'static>,
_request: Arc<ErasedRequest>,
}
impl Drop for ErasedIoHandler {
fn drop(&mut self) { }
}
impl ErasedRequest {
pub fn new(
rocket: Arc<Rocket<Orbit>>,
parts: Parts,
constructor: impl for<'r> FnOnce(
&'r Rocket<Orbit>,
&'r Parts
) -> Request<'r>,
) -> ErasedRequest {
let rocket: Arc<Rocket<Orbit>> = rocket;
let parts: Box<Parts> = Box::new(parts);
let request: Request<'_> = {
let rocket: &Rocket<Orbit> = &*rocket;
let rocket: &'static Rocket<Orbit> = unsafe { transmute(rocket) };
let parts: &Parts = &*parts;
let parts: &'static Parts = unsafe { transmute(parts) };
constructor(&rocket, &parts)
};
ErasedRequest { _rocket: rocket, _parts: parts, request, }
}
pub async fn into_response<T: Send + Sync + 'static>(
self,
incoming: Incoming,
data_builder: impl for<'r> FnOnce(&'r mut Incoming) -> Data<'r>,
preprocess: impl for<'r, 'x> FnOnce(
&'r Rocket<Orbit>,
&'r mut Request<'x>,
&'r mut Data<'x>
) -> BoxFuture<'r, T>,
dispatch: impl for<'r> FnOnce(
T,
&'r Rocket<Orbit>,
&'r Request<'r>,
Data<'r>
) -> BoxFuture<'r, Response<'r>>,
) -> ErasedResponse {
let mut incoming = Box::new(incoming);
let mut data: Data<'_> = {
let incoming: &mut Incoming = &mut *incoming;
let incoming: &'static mut Incoming = unsafe { transmute(incoming) };
data_builder(incoming)
};
let mut parent = Arc::new(self);
let token: T = {
let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap();
let rocket: &Rocket<Orbit> = &*parent._rocket;
let request: &mut Request<'_> = &mut parent.request;
let data: &mut Data<'_> = &mut data;
preprocess(rocket, request, data).await
};
let parent = parent;
let response: Response<'_> = {
let parent: &ErasedRequest = &*parent;
let parent: &'static ErasedRequest = unsafe { transmute(parent) };
let rocket: &Rocket<Orbit> = &*parent._rocket;
let request: &Request<'_> = &parent.request;
dispatch(token, rocket, request, data).await
};
ErasedResponse {
_request: parent,
_incoming: incoming,
response: response,
}
}
}
impl ErasedResponse {
pub fn inner<'a>(&'a self) -> &'a Response<'a> {
static_assert_covariance!(Response);
&self.response
}
pub fn with_inner_mut<'a, T>(
&'a mut self,
f: impl for<'r> FnOnce(&'a mut Response<'r>) -> T
) -> T {
static_assert_covariance!(Response);
f(&mut self.response)
}
pub fn to_io_handler<'a>(
&'a mut self,
constructor: impl for<'r> FnOnce(
&'r Request<'r>,
&'a mut Response<'r>,
) -> Option<Box<dyn IoHandler + 'r>>
) -> Option<ErasedIoHandler> {
let parent: Arc<ErasedRequest> = self._request.clone();
let io: Option<Box<dyn IoHandler + '_>> = {
let parent: &ErasedRequest = &*parent;
let parent: &'static ErasedRequest = unsafe { transmute(parent) };
let request: &Request<'_> = &parent.request;
constructor(request, &mut self.response)
};
io.map(|io| ErasedIoHandler { _request: parent, io })
}
}
impl ErasedIoHandler {
pub fn with_inner_mut<'a, T: 'a>(
&'a mut self,
f: impl for<'r> FnOnce(&'a mut Box<dyn IoHandler + 'r>) -> T
) -> T {
fn _assert_covariance<'x: 'y, 'y>(
x: &'y Box<dyn IoHandler + 'x>
) -> &'y Box<dyn IoHandler + 'y> { x }
f(&mut self.io)
}
pub fn take<'a>(&'a mut self) -> Box<dyn IoHandler + 'a> {
fn _assert_covariance<'x: 'y, 'y>(
x: &'y Box<dyn IoHandler + 'x>
) -> &'y Box<dyn IoHandler + 'y> { x }
self.with_inner_mut(|handler| std::mem::replace(handler, Box::new(())))
}
}
impl AsyncRead for ErasedResponse {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.get_mut().with_inner_mut(|r| Pin::new(r.body_mut()).poll_read(cx, buf))
}
}

View File

@ -74,11 +74,8 @@ pub struct Error {
#[derive(Debug)] #[derive(Debug)]
#[non_exhaustive] #[non_exhaustive]
pub enum ErrorKind { pub enum ErrorKind {
/// Binding to the provided address/port failed. /// Binding to the network interface failed.
Bind(io::Error), Bind(Box<dyn StdError + Send>),
/// Binding via TLS to the provided address/port failed.
#[cfg(feature = "tls")]
TlsBind(crate::http::tls::error::Error),
/// An I/O error occurred during launch. /// An I/O error occurred during launch.
Io(io::Error), Io(io::Error),
/// A valid [`Config`](crate::Config) could not be extracted from the /// A valid [`Config`](crate::Config) could not be extracted from the
@ -90,15 +87,10 @@ pub enum ErrorKind {
FailedFairings(Vec<crate::fairing::Info>), FailedFairings(Vec<crate::fairing::Info>),
/// Sentinels requested abort. /// Sentinels requested abort.
SentinelAborts(Vec<crate::sentinel::Sentry>), SentinelAborts(Vec<crate::sentinel::Sentry>),
/// The configuration profile is not debug but not secret key is configured. /// The configuration profile is not debug but no secret key is configured.
InsecureSecretKey(Profile), InsecureSecretKey(Profile),
/// Shutdown failed. /// Shutdown failed. Contains the Rocket instance that failed to shutdown.
Shutdown( Shutdown(Arc<Rocket<Orbit>>),
/// The instance of Rocket that failed to shutdown.
Arc<Rocket<Orbit>>,
/// The error that occurred during shutdown, if any.
Option<Box<dyn StdError + Send + Sync>>
),
} }
/// An error that occurs when a value was unexpectedly empty. /// An error that occurs when a value was unexpectedly empty.
@ -111,20 +103,24 @@ impl From<ErrorKind> for Error {
} }
} }
impl From<figment::Error> for Error {
fn from(e: figment::Error) -> Self {
Error::new(ErrorKind::Config(e))
}
}
impl From<io::Error> for Error {
fn from(e: io::Error) -> Self {
Error::new(ErrorKind::Io(e))
}
}
impl Error { impl Error {
#[inline(always)] #[inline(always)]
pub(crate) fn new(kind: ErrorKind) -> Error { pub(crate) fn new(kind: ErrorKind) -> Error {
Error { handled: AtomicBool::new(false), kind } Error { handled: AtomicBool::new(false), kind }
} }
#[inline(always)]
pub(crate) fn shutdown<E>(rocket: Arc<Rocket<Orbit>>, error: E) -> Error
where E: Into<Option<crate::http::hyper::Error>>
{
let error = error.into().map(|e| Box::new(e) as Box<dyn StdError + Sync + Send>);
Error::new(ErrorKind::Shutdown(rocket, error))
}
#[inline(always)] #[inline(always)]
fn was_handled(&self) -> bool { fn was_handled(&self) -> bool {
self.handled.load(Ordering::Acquire) self.handled.load(Ordering::Acquire)
@ -176,9 +172,9 @@ impl Error {
self.mark_handled(); self.mark_handled();
match self.kind() { match self.kind() {
ErrorKind::Bind(ref e) => { ErrorKind::Bind(ref e) => {
error!("Rocket failed to bind network socket to given address/port."); error!("Binding to the network interface failed.");
info_!("{}", e); info_!("{}", e);
"aborting due to socket bind error" "aborting due to bind error"
} }
ErrorKind::Io(ref e) => { ErrorKind::Io(ref e) => {
error!("Rocket failed to launch due to an I/O error."); error!("Rocket failed to launch due to an I/O error.");
@ -229,20 +225,10 @@ impl Error {
"aborting due to sentinel-triggered abort(s)" "aborting due to sentinel-triggered abort(s)"
} }
ErrorKind::Shutdown(_, error) => { ErrorKind::Shutdown(_) => {
error!("Rocket failed to shutdown gracefully."); error!("Rocket failed to shutdown gracefully.");
if let Some(e) = error {
info_!("{}", e);
}
"aborting due to failed shutdown" "aborting due to failed shutdown"
} }
#[cfg(feature = "tls")]
ErrorKind::TlsBind(e) => {
error!("Rocket failed to bind via TLS to network socket.");
info_!("{}", e);
"aborting due to TLS bind error"
}
} }
} }
} }
@ -260,10 +246,7 @@ impl fmt::Display for ErrorKind {
ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f), ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f),
ErrorKind::Config(_) => "failed to extract configuration".fmt(f), ErrorKind::Config(_) => "failed to extract configuration".fmt(f),
ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f), ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f),
ErrorKind::Shutdown(_, Some(e)) => write!(f, "shutdown failed: {e}"), ErrorKind::Shutdown(_) => "shutdown failed".fmt(f),
ErrorKind::Shutdown(_, None) => "shutdown failed".fmt(f),
#[cfg(feature = "tls")]
ErrorKind::TlsBind(e) => write!(f, "TLS bind failed: {e}"),
} }
} }
} }
@ -308,3 +291,42 @@ impl fmt::Display for Empty {
} }
impl StdError for Empty { } impl StdError for Empty { }
/// Log an error that occurs during request processing
pub(crate) fn log_server_error(error: &Box<dyn StdError + Send + Sync>) {
struct ServerError<'a>(&'a (dyn StdError + 'static));
impl fmt::Display for ServerError<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let error = &self.0;
if let Some(e) = error.downcast_ref::<hyper::Error>() {
write!(f, "request processing failed: {e}")?;
} else if let Some(e) = error.downcast_ref::<io::Error>() {
write!(f, "connection I/O error: ")?;
match e.kind() {
io::ErrorKind::NotConnected => write!(f, "remote disconnected")?,
io::ErrorKind::UnexpectedEof => write!(f, "remote sent early eof")?,
io::ErrorKind::ConnectionReset
| io::ErrorKind::ConnectionAborted
| io::ErrorKind::BrokenPipe => write!(f, "terminated by remote")?,
_ => write!(f, "{e}")?,
}
} else {
write!(f, "http server error: {error}")?;
}
if let Some(e) = error.source() {
write!(f, " ({})", ServerError(e))?;
}
Ok(())
}
}
if error.downcast_ref::<hyper::Error>().is_some() {
warn!("{}", ServerError(&**error))
} else {
error!("{}", ServerError(&**error))
}
}

View File

@ -1,404 +0,0 @@
use std::{io, time::Duration};
use std::task::{Poll, Context};
use std::pin::Pin;
use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::time::{sleep, Sleep};
use futures::stream::Stream;
use futures::future::{self, Future, FutureExt};
pin_project! {
pub struct ReaderStream<R> {
#[pin]
reader: Option<R>,
buf: BytesMut,
cap: usize,
}
}
impl<R: AsyncRead> Stream for ReaderStream<R> {
type Item = std::io::Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use tokio_util::io::poll_read_buf;
let mut this = self.as_mut().project();
let reader = match this.reader.as_pin_mut() {
Some(r) => r,
None => return Poll::Ready(None),
};
if this.buf.capacity() == 0 {
this.buf.reserve(*this.cap);
}
match poll_read_buf(reader, cx, &mut this.buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => {
self.project().reader.set(None);
Poll::Ready(Some(Err(err)))
}
Poll::Ready(Ok(0)) => {
self.project().reader.set(None);
Poll::Ready(None)
}
Poll::Ready(Ok(_)) => {
let chunk = this.buf.split();
Poll::Ready(Some(Ok(chunk.freeze())))
}
}
}
}
pub trait AsyncReadExt: AsyncRead + Sized {
fn into_bytes_stream(self, cap: usize) -> ReaderStream<Self> {
ReaderStream { reader: Some(self), cap, buf: BytesMut::with_capacity(cap) }
}
}
impl<T: AsyncRead> AsyncReadExt for T { }
pub trait PollExt<T, E> {
fn map_err_ext<U, F>(self, f: F) -> Poll<Option<Result<T, U>>>
where F: FnOnce(E) -> U;
}
impl<T, E> PollExt<T, E> for Poll<Option<Result<T, E>>> {
/// Changes the error value of this `Poll` with the closure provided.
fn map_err_ext<U, F>(self, f: F) -> Poll<Option<Result<T, U>>>
where F: FnOnce(E) -> U
{
match self {
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(t))),
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(f(e)))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
pin_project! {
/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
#[must_use = "streams do nothing unless polled"]
pub struct Chain<T, U> {
#[pin]
first: T,
#[pin]
second: U,
done_first: bool,
}
}
impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
pub(crate) fn new(first: T, second: U) -> Self {
Self { first, second, done_first: false }
}
}
impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
/// Gets references to the underlying readers in this `Chain`.
pub fn get_ref(&self) -> (&T, &U) {
(&self.first, &self.second)
}
}
impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let me = self.project();
if !*me.done_first {
let init_rem = buf.remaining();
futures::ready!(me.first.poll_read(cx, buf))?;
if buf.remaining() == init_rem {
*me.done_first = true;
} else {
return Poll::Ready(Ok(()));
}
}
me.second.poll_read(cx, buf)
}
}
enum State {
/// I/O has not been cancelled. Proceed as normal.
Active,
/// I/O has been cancelled. See if we can finish before the timer expires.
Grace(Pin<Box<Sleep>>),
/// Grace period elapsed. Shutdown the connection, waiting for the timer
/// until we force close.
Mercy(Pin<Box<Sleep>>),
}
pin_project! {
/// I/O that can be cancelled when a future `F` resolves.
#[must_use = "futures do nothing unless polled"]
pub struct CancellableIo<F, I> {
#[pin]
io: Option<I>,
#[pin]
trigger: future::Fuse<F>,
state: State,
grace: Duration,
mercy: Duration,
}
}
impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
pub fn new(trigger: F, io: I, grace: Duration, mercy: Duration) -> Self {
CancellableIo {
grace, mercy,
io: Some(io),
trigger: trigger.fuse(),
state: State::Active,
}
}
pub fn io(&self) -> Option<&I> {
self.io.as_ref()
}
/// Run `do_io` while connection processing should continue.
fn poll_trigger_then<T>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
) -> Poll<io::Result<T>> {
let mut me = self.as_mut().project();
let io = match me.io.as_pin_mut() {
Some(io) => io,
None => return Poll::Ready(Err(gone())),
};
loop {
match me.state {
State::Active => {
if me.trigger.as_mut().poll(cx).is_ready() {
*me.state = State::Grace(Box::pin(sleep(*me.grace)));
} else {
return do_io(io, cx);
}
}
State::Grace(timer) => {
if timer.as_mut().poll(cx).is_ready() {
*me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
} else {
return do_io(io, cx);
}
}
State::Mercy(timer) => {
if timer.as_mut().poll(cx).is_ready() {
self.project().io.set(None);
return Poll::Ready(Err(time_out()));
} else {
let result = futures::ready!(io.poll_shutdown(cx));
self.project().io.set(None);
return match result {
Err(e) => Poll::Ready(Err(e)),
Ok(()) => Poll::Ready(Err(gone()))
};
}
},
}
}
}
}
fn time_out() -> io::Error {
io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out")
}
fn gone() -> io::Error {
io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated")
}
impl<F: Future, I: AsyncRead + AsyncWrite> AsyncRead for CancellableIo<F, I> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf))
}
}
impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf))
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx))
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx))
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs))
}
fn is_write_vectored(&self) -> bool {
self.io().map(|io| io.is_write_vectored()).unwrap_or(false)
}
}
use crate::http::private::{Listener, Connection, Certificates};
impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
fn peer_address(&self) -> Option<std::net::SocketAddr> {
self.io().and_then(|io| io.peer_address())
}
fn peer_certificates(&self) -> Option<Certificates> {
self.io().and_then(|io| io.peer_certificates())
}
fn enable_nodelay(&self) -> io::Result<()> {
match self.io() {
Some(io) => io.enable_nodelay(),
None => Ok(())
}
}
}
pin_project! {
pub struct CancellableListener<F, L> {
pub trigger: F,
#[pin]
pub listener: L,
pub grace: Duration,
pub mercy: Duration,
}
}
impl<F, L> CancellableListener<F, L> {
pub fn new(trigger: F, listener: L, grace: u64, mercy: u64) -> Self {
let (grace, mercy) = (Duration::from_secs(grace), Duration::from_secs(mercy));
CancellableListener { trigger, listener, grace, mercy }
}
}
impl<L: Listener, F: Future + Clone> Listener for CancellableListener<F, L> {
type Connection = CancellableIo<F, L::Connection>;
fn local_addr(&self) -> Option<std::net::SocketAddr> {
self.listener.local_addr()
}
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>> {
self.as_mut().project().listener
.poll_accept(cx)
.map(|res| res.map(|conn| {
CancellableIo::new(self.trigger.clone(), conn, self.grace, self.mercy)
}))
}
}
pub trait StreamExt: Sized + Stream {
fn join<U>(self, other: U) -> Join<Self, U>
where U: Stream<Item = Self::Item>;
}
impl<S: Stream> StreamExt for S {
fn join<U>(self, other: U) -> Join<Self, U>
where U: Stream<Item = Self::Item>
{
Join::new(self, other)
}
}
pin_project! {
/// Stream returned by the [`join`](super::StreamExt::join) method.
pub struct Join<T, U> {
#[pin]
a: T,
#[pin]
b: U,
// When `true`, poll `a` first, otherwise, `poll` b`.
toggle: bool,
// Set when either `a` or `b` return `None`.
done: bool,
}
}
impl<T, U> Join<T, U> {
pub(super) fn new(a: T, b: U) -> Join<T, U>
where T: Stream, U: Stream,
{
Join { a, b, toggle: false, done: false, }
}
fn poll_next<A: Stream, B: Stream<Item = A::Item>>(
first: Pin<&mut A>,
second: Pin<&mut B>,
done: &mut bool,
cx: &mut Context<'_>,
) -> Poll<Option<A::Item>> {
match first.poll_next(cx) {
Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) }
Poll::Pending => match second.poll_next(cx) {
Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) }
Poll::Pending => Poll::Pending
}
}
}
}
impl<T, U> Stream for Join<T, U>
where T: Stream,
U: Stream<Item = T::Item>,
{
type Item = T::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
if self.done {
return Poll::Ready(None);
}
let me = self.project();
*me.toggle = !*me.toggle;
match *me.toggle {
true => Self::poll_next(me.a, me.b, me.done, cx),
false => Self::poll_next(me.b, me.a, me.done, cx),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let (left_low, left_high) = self.a.size_hint();
let (right_low, right_high) = self.b.size_hint();
let low = left_low.saturating_add(right_low);
let high = match (left_high, right_high) {
(Some(h1), Some(h2)) => h1.checked_add(h2),
_ => None,
};
(low, high)
}
}

View File

@ -341,6 +341,7 @@
// `key_contexts: Vec<K::Context>`, a vector of `value_contexts: // `key_contexts: Vec<K::Context>`, a vector of `value_contexts:
// Vec<V::Context>`, a `mapping` from a string index to an integer index // Vec<V::Context>`, a `mapping` from a string index to an integer index
// into the `contexts`, and a vector of `errors`. // into the `contexts`, and a vector of `errors`.
//
// 2. **Push.** An index is required; an error is emitted and `push` returns // 2. **Push.** An index is required; an error is emitted and `push` returns
// if they field's first key does not contain an index. If the first key // if they field's first key does not contain an index. If the first key
// contains _one_ index, a new `K::Context` and `V::Context` are created. // contains _one_ index, a new `K::Context` and `V::Context` are created.
@ -356,9 +357,9 @@
// to `second` in `mapping`. If the first index is `k`, the field, // to `second` in `mapping`. If the first index is `k`, the field,
// stripped of the first key, is pushed to the key's context; the same is // stripped of the first key, is pushed to the key's context; the same is
// done for the value's context is the first index is `v`. // done for the value's context is the first index is `v`.
//
// 3. **Finalization.** Every context is finalized; errors and `Ok` values // 3. **Finalization.** Every context is finalized; errors and `Ok` values
// are collected. TODO: FINISH. Split this into two: one for single-index, // are collected.
// another for two-indices.
mod field; mod field;
mod options; mod options;

View File

@ -2,7 +2,7 @@ use std::io;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use tokio::fs::File; use tokio::fs::{File, OpenOptions};
use crate::request::Request; use crate::request::Request;
use crate::response::{self, Responder}; use crate::response::{self, Responder};
@ -60,7 +60,7 @@ impl NamedFile {
/// } /// }
/// ``` /// ```
pub async fn open<P: AsRef<Path>>(path: P) -> io::Result<NamedFile> { pub async fn open<P: AsRef<Path>>(path: P) -> io::Result<NamedFile> {
// FIXME: Grab the file size here and prohibit `seek`ing later (or else // TODO: Grab the file size here and prohibit `seek`ing later (or else
// the file's effective size may change), to save on the cost of doing // the file's effective size may change), to save on the cost of doing
// all of those `seek`s to determine the file size. But, what happens if // all of those `seek`s to determine the file size. But, what happens if
// the file gets changed between now and then? // the file gets changed between now and then?
@ -68,6 +68,11 @@ impl NamedFile {
Ok(NamedFile(path.as_ref().to_path_buf(), file)) Ok(NamedFile(path.as_ref().to_path_buf(), file))
} }
pub async fn open_with<P: AsRef<Path>>(path: P, opts: &OpenOptions) -> io::Result<NamedFile> {
let file = opts.open(path.as_ref()).await?;
Ok(NamedFile(path.as_ref().to_path_buf(), file))
}
/// Retrieve the underlying `File`. /// Retrieve the underlying `File`.
/// ///
/// # Example /// # Example

View File

@ -2,11 +2,10 @@ use std::fmt;
use parking_lot::Mutex; use parking_lot::Mutex;
use crate::http::private::cookie;
use crate::{Rocket, Orbit}; use crate::{Rocket, Orbit};
#[doc(inline)] #[doc(inline)]
pub use self::cookie::{Cookie, SameSite, Iter}; pub use cookie::{Cookie, SameSite, Iter};
/// Collection of one or more HTTP cookies. /// Collection of one or more HTTP cookies.
/// ///
@ -167,7 +166,7 @@ pub(crate) struct CookieState<'a> {
#[derive(Clone)] #[derive(Clone)]
enum Op { enum Op {
Add(Cookie<'static>, bool), Add(Cookie<'static>, bool),
Remove(Cookie<'static>, bool), Remove(Cookie<'static>),
} }
impl<'a> CookieJar<'a> { impl<'a> CookieJar<'a> {
@ -177,7 +176,7 @@ impl<'a> CookieJar<'a> {
ops: Mutex::new(Vec::new()), ops: Mutex::new(Vec::new()),
state: CookieState { state: CookieState {
// This is updated dynamically when headers are received. // This is updated dynamically when headers are received.
secure: rocket.config().tls_enabled(), secure: rocket.endpoint().is_tls(),
config: rocket.config(), config: rocket.config(),
} }
} }
@ -256,7 +255,7 @@ impl<'a> CookieJar<'a> {
for op in ops.iter().rev().filter(|op| op.cookie().name() == name) { for op in ops.iter().rev().filter(|op| op.cookie().name() == name) {
match op { match op {
Op::Add(c, _) => return Some(c.clone()), Op::Add(c, _) => return Some(c.clone()),
Op::Remove(_, _) => return None, Op::Remove(_) => return None,
} }
} }
@ -389,7 +388,7 @@ impl<'a> CookieJar<'a> {
pub fn remove<C: Into<Cookie<'static>>>(&self, cookie: C) { pub fn remove<C: Into<Cookie<'static>>>(&self, cookie: C) {
let mut cookie = cookie.into(); let mut cookie = cookie.into();
Self::set_removal_defaults(&mut cookie); Self::set_removal_defaults(&mut cookie);
self.ops.lock().push(Op::Remove(cookie, false)); self.ops.lock().push(Op::Remove(cookie));
} }
/// Removes the private `cookie` from the collection. /// Removes the private `cookie` from the collection.
@ -432,7 +431,7 @@ impl<'a> CookieJar<'a> {
pub fn remove_private<C: Into<Cookie<'static>>>(&self, cookie: C) { pub fn remove_private<C: Into<Cookie<'static>>>(&self, cookie: C) {
let mut cookie = cookie.into(); let mut cookie = cookie.into();
Self::set_removal_defaults(&mut cookie); Self::set_removal_defaults(&mut cookie);
self.ops.lock().push(Op::Remove(cookie, true)); self.ops.lock().push(Op::Remove(cookie));
} }
/// Returns an iterator over all of the _original_ cookies present in this /// Returns an iterator over all of the _original_ cookies present in this
@ -477,7 +476,7 @@ impl<'a> CookieJar<'a> {
Op::Add(c, true) => { Op::Add(c, true) => {
jar.private_mut(&self.state.config.secret_key.key).add(c); jar.private_mut(&self.state.config.secret_key.key).add(c);
} }
Op::Remove(mut c, _) => { Op::Remove(mut c) => {
if self.jar.get(c.name()).is_some() { if self.jar.get(c.name()).is_some() {
c.make_removal(); c.make_removal();
jar.add(c); jar.add(c);
@ -595,7 +594,7 @@ impl<'a> Clone for CookieJar<'a> {
impl Op { impl Op {
fn cookie(&self) -> &Cookie<'static> { fn cookie(&self) -> &Cookie<'static> {
match self { match self {
Op::Add(c, _) | Op::Remove(c, _) => c Op::Add(c, _) | Op::Remove(c) => c
} }
} }
} }

12
core/lib/src/http/mod.rs Normal file
View File

@ -0,0 +1,12 @@
//! Types that map to concepts in HTTP.
//!
//! This module exports types that map to HTTP concepts or to the underlying
//! HTTP library when needed.
mod cookies;
#[doc(inline)]
pub use rocket_http::*;
#[doc(inline)]
pub use cookies::*;

View File

@ -7,7 +7,9 @@
#![cfg_attr(nightly, feature(decl_macro))] #![cfg_attr(nightly, feature(decl_macro))]
#![warn(rust_2018_idioms)] #![warn(rust_2018_idioms)]
#![warn(missing_docs)] // #![warn(missing_docs)]
#![allow(async_fn_in_trait)]
#![allow(refining_impl_trait)]
//! # Rocket - Core API Documentation //! # Rocket - Core API Documentation
//! //!
@ -109,18 +111,24 @@
/// These are public dependencies! Update docs if these are changed, especially /// These are public dependencies! Update docs if these are changed, especially
/// figment's version number in docs. /// figment's version number in docs.
#[doc(hidden)] pub use yansi; #[doc(hidden)]
#[doc(hidden)] pub use async_stream; pub use yansi;
#[doc(hidden)]
pub use async_stream;
pub use futures; pub use futures;
pub use tokio; pub use tokio;
pub use figment; pub use figment;
pub use time; pub use time;
#[doc(hidden)] #[doc(hidden)]
#[macro_use] pub mod log; #[macro_use]
#[macro_use] pub mod outcome; pub mod log;
#[macro_use] pub mod data; #[macro_use]
#[doc(hidden)] pub mod sentinel; pub mod outcome;
#[macro_use]
pub mod data;
#[doc(hidden)]
pub mod sentinel;
pub mod local; pub mod local;
pub mod request; pub mod request;
pub mod response; pub mod response;
@ -133,74 +141,41 @@ pub mod route;
pub mod serde; pub mod serde;
pub mod shield; pub mod shield;
pub mod fs; pub mod fs;
pub mod http;
// Reexport of HTTP everything. pub mod listener;
pub mod http { #[cfg(feature = "tls")]
//! Types that map to concepts in HTTP. #[cfg_attr(nightly, doc(cfg(feature = "tls")))]
//! pub mod tls;
//! This module exports types that map to HTTP concepts or to the underlying
//! HTTP library when needed.
#[doc(inline)]
pub use rocket_http::*;
/// Re-exported hyper HTTP library types.
///
/// All types that are re-exported from Hyper reside inside of this module.
/// These types will, with certainty, be removed with time, but they reside here
/// while necessary.
pub mod hyper {
#[doc(hidden)]
pub use rocket_http::hyper::*;
pub use rocket_http::hyper::header;
}
#[doc(inline)]
pub use crate::cookies::*;
}
#[cfg(feature = "mtls")] #[cfg(feature = "mtls")]
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub mod mtls; pub mod mtls;
/// TODO: We need a futures mod or something. mod util;
mod trip_wire;
mod shutdown; mod shutdown;
mod server; mod server;
mod ext; mod lifecycle;
mod state; mod state;
mod cookies;
mod rocket; mod rocket;
mod router; mod router;
mod phase; mod phase;
mod erased;
#[doc(hidden)] pub use either::Either;
#[doc(inline)] pub use rocket_codegen::*;
#[doc(inline)] pub use crate::response::Response; #[doc(inline)] pub use crate::response::Response;
#[doc(inline)] pub use crate::data::Data; #[doc(inline)] pub use crate::data::Data;
#[doc(inline)] pub use crate::config::Config; #[doc(inline)] pub use crate::config::Config;
#[doc(inline)] pub use crate::catcher::Catcher; #[doc(inline)] pub use crate::catcher::Catcher;
#[doc(inline)] pub use crate::route::Route; #[doc(inline)] pub use crate::route::Route;
#[doc(hidden)] pub use either::Either; #[doc(inline)] pub use crate::phase::{Phase, Build, Ignite, Orbit};
#[doc(inline)] pub use phase::{Phase, Build, Ignite, Orbit}; #[doc(inline)] pub use crate::error::Error;
#[doc(inline)] pub use error::Error; #[doc(inline)] pub use crate::sentinel::Sentinel;
#[doc(inline)] pub use sentinel::Sentinel;
#[doc(inline)] pub use crate::request::Request; #[doc(inline)] pub use crate::request::Request;
#[doc(inline)] pub use crate::rocket::Rocket; #[doc(inline)] pub use crate::rocket::Rocket;
#[doc(inline)] pub use crate::shutdown::Shutdown; #[doc(inline)] pub use crate::shutdown::Shutdown;
#[doc(inline)] pub use crate::state::State; #[doc(inline)] pub use crate::state::State;
#[doc(inline)] pub use rocket_codegen::*;
/// Creates a [`Rocket`] instance with the default config provider: aliases
/// [`Rocket::build()`].
pub fn build() -> Rocket<Build> {
Rocket::build()
}
/// Creates a [`Rocket`] instance with a custom config provider: aliases
/// [`Rocket::custom()`].
pub fn custom<T: figment::Provider>(provider: T) -> Rocket<Build> {
Rocket::custom(provider)
}
/// Retrofits support for `async fn` in trait impls and declarations. /// Retrofits support for `async fn` in trait impls and declarations.
/// ///
@ -231,6 +206,20 @@ pub fn custom<T: figment::Provider>(provider: T) -> Rocket<Build> {
#[doc(inline)] #[doc(inline)]
pub use async_trait::async_trait; pub use async_trait::async_trait;
const WORKER_PREFIX: &'static str = "rocket-worker";
/// Creates a [`Rocket`] instance with the default config provider: aliases
/// [`Rocket::build()`].
pub fn build() -> Rocket<Build> {
Rocket::build()
}
/// Creates a [`Rocket`] instance with a custom config provider: aliases
/// [`Rocket::custom()`].
pub fn custom<T: figment::Provider>(provider: T) -> Rocket<Build> {
Rocket::custom(provider)
}
/// WARNING: This is unstable! Do not use this method outside of Rocket! /// WARNING: This is unstable! Do not use this method outside of Rocket!
#[doc(hidden)] #[doc(hidden)]
pub fn async_run<F, R>(fut: F, workers: usize, sync: usize, force_end: bool, name: &str) -> R pub fn async_run<F, R>(fut: F, workers: usize, sync: usize, force_end: bool, name: &str) -> R
@ -255,7 +244,7 @@ pub fn async_run<F, R>(fut: F, workers: usize, sync: usize, force_end: bool, nam
/// WARNING: This is unstable! Do not use this method outside of Rocket! /// WARNING: This is unstable! Do not use this method outside of Rocket!
#[doc(hidden)] #[doc(hidden)]
pub fn async_test<R>(fut: impl std::future::Future<Output = R>) -> R { pub fn async_test<R>(fut: impl std::future::Future<Output = R>) -> R {
async_run(fut, 1, 32, true, "rocket-worker-test-thread") async_run(fut, 1, 32, true, &format!("{WORKER_PREFIX}-test-thread"))
} }
/// WARNING: This is unstable! Do not use this method outside of Rocket! /// WARNING: This is unstable! Do not use this method outside of Rocket!
@ -276,7 +265,7 @@ pub fn async_main<R>(fut: impl std::future::Future<Output = R> + Send) -> R {
let workers = fig.extract_inner(Config::WORKERS).unwrap_or_else(bail); let workers = fig.extract_inner(Config::WORKERS).unwrap_or_else(bail);
let max_blocking = fig.extract_inner(Config::MAX_BLOCKING).unwrap_or_else(bail); let max_blocking = fig.extract_inner(Config::MAX_BLOCKING).unwrap_or_else(bail);
let force = fig.focus(Config::SHUTDOWN).extract_inner("force").unwrap_or_else(bail); let force = fig.focus(Config::SHUTDOWN).extract_inner("force").unwrap_or_else(bail);
async_run(fut, workers, max_blocking, force, "rocket-worker-thread") async_run(fut, workers, max_blocking, force, &format!("{WORKER_PREFIX}-thread"))
} }
/// Executes a `future` to completion on a new tokio-based Rocket async runtime. /// Executes a `future` to completion on a new tokio-based Rocket async runtime.
@ -359,3 +348,14 @@ pub fn execute<R, F>(future: F) -> R
{ {
async_main(future) async_main(future)
} }
/// Returns a future that evalutes to `true` exactly when there is a presently
/// running tokio async runtime that was likely started by Rocket.
fn running_within_rocket_async_rt() -> impl std::future::Future<Output = bool> {
use futures::FutureExt;
tokio::task::spawn_blocking(|| {
let this = std::thread::current();
this.name().map_or(false, |s| s.starts_with(WORKER_PREFIX))
}).map(|r| r.unwrap_or(false))
}

272
core/lib/src/lifecycle.rs Normal file
View File

@ -0,0 +1,272 @@
use yansi::Paint;
use futures::future::{FutureExt, Future};
use crate::{route, Rocket, Orbit, Request, Response, Data};
use crate::data::IoHandler;
use crate::http::{Method, Status, Header};
use crate::outcome::Outcome;
use crate::form::Form;
// A token returned to force the execution of one method before another.
pub(crate) struct RequestToken;
async fn catch_handle<Fut, T, F>(name: Option<&str>, run: F) -> Option<T>
where F: FnOnce() -> Fut, Fut: Future<Output = T>,
{
macro_rules! panic_info {
($name:expr, $e:expr) => {{
match $name {
Some(name) => error_!("Handler {} panicked.", name.primary()),
None => error_!("A handler panicked.")
};
info_!("This is an application bug.");
info_!("A panic in Rust must be treated as an exceptional event.");
info_!("Panicking is not a suitable error handling mechanism.");
info_!("Unwinding, the result of a panic, is an expensive operation.");
info_!("Panics will degrade application performance.");
info_!("Instead of panicking, return `Option` and/or `Result`.");
info_!("Values of either type can be returned directly from handlers.");
warn_!("A panic is treated as an internal server error.");
$e
}}
}
let run = std::panic::AssertUnwindSafe(run);
let fut = std::panic::catch_unwind(move || run())
.map_err(|e| panic_info!(name, e))
.ok()?;
std::panic::AssertUnwindSafe(fut)
.catch_unwind()
.await
.map_err(|e| panic_info!(name, e))
.ok()
}
impl Rocket<Orbit> {
/// Preprocess the request for Rocket things. Currently, this means:
///
/// * Rewriting the method in the request if _method form field exists.
/// * Run the request fairings.
///
/// This is the only place during lifecycle processing that `Request` is
/// mutable. Keep this in-sync with the `FromForm` derive.
pub(crate) async fn preprocess(
&self,
req: &mut Request<'_>,
data: &mut Data<'_>
) -> RequestToken {
// Check if this is a form and if the form contains the special _method
// field which we use to reinterpret the request's method.
let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
let peek_buffer = data.peek(max_len).await;
let is_form = req.content_type().map_or(false, |ct| ct.is_form());
if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len {
let method = std::str::from_utf8(peek_buffer).ok()
.and_then(|raw_form| Form::values(raw_form).next())
.filter(|field| field.name == "_method")
.and_then(|field| field.value.parse().ok());
if let Some(method) = method {
req.set_method(method);
}
}
// Run request fairings.
self.fairings.handle_request(req, data).await;
RequestToken
}
/// Dispatches the request to the router and processes the outcome to
/// produce a response. If the initial outcome is a *forward* and the
/// request was a HEAD request, the request is rewritten and rerouted as a
/// GET. This is automatic HEAD handling.
///
/// After performing the above, if the outcome is a forward or error, the
/// appropriate error catcher is invoked to produce the response. Otherwise,
/// the successful response is used directly.
///
/// Finally, new cookies in the cookie jar are added to the response,
/// Rocket-specific headers are written, and response fairings are run. Note
/// that error responses have special cookie handling. See `handle_error`.
pub(crate) async fn dispatch<'r, 's: 'r>(
&'s self,
_token: RequestToken,
request: &'r Request<'s>,
data: Data<'r>,
// io_stream: impl Future<Output = io::Result<IoStream>> + Send,
) -> Response<'r> {
info!("{}:", request);
// Remember if the request is `HEAD` for later body stripping.
let was_head_request = request.method() == Method::Head;
// Route the request and run the user's handlers.
let mut response = match self.route(request, data).await {
Outcome::Success(response) => response,
Outcome::Forward((data, _)) if request.method() == Method::Head => {
info_!("Autohandling {} request.", "HEAD".primary().bold());
// Dispatch the request again with Method `GET`.
request._set_method(Method::Get);
match self.route(request, data).await {
Outcome::Success(response) => response,
Outcome::Error(status) => self.dispatch_error(status, request).await,
Outcome::Forward((_, status)) => self.dispatch_error(status, request).await,
}
}
Outcome::Forward((_, status)) => self.dispatch_error(status, request).await,
Outcome::Error(status) => self.dispatch_error(status, request).await,
};
// Set the cookies. Note that error responses will only include cookies
// set by the error handler. See `handle_error` for more.
let delta_jar = request.cookies().take_delta_jar();
for cookie in delta_jar.delta() {
response.adjoin_header(cookie);
}
// Add a default 'Server' header if it isn't already there.
// TODO: If removing Hyper, write out `Date` header too.
if let Some(ident) = request.rocket().config.ident.as_str() {
if !response.headers().contains("Server") {
response.set_header(Header::new("Server", ident));
}
}
// Run the response fairings.
self.fairings.handle_response(request, &mut response).await;
// Strip the body if this is a `HEAD` request.
if was_head_request {
response.strip_body();
}
// TODO: Should upgrades be handled here? We miss them on local clients.
response
}
pub(crate) fn extract_io_handler<'r>(
request: &Request<'_>,
response: &mut Response<'r>,
// io_stream: impl Future<Output = io::Result<IoStream>> + Send,
) -> Option<Box<dyn IoHandler + 'r>> {
let upgrades = request.headers().get("upgrade");
let Ok(upgrade) = response.search_upgrades(upgrades) else {
warn_!("Request wants upgrade but no I/O handler matched.");
info_!("Request is not being upgraded.");
return None;
};
if let Some((proto, io_handler)) = upgrade {
info_!("Attemping upgrade with {proto} I/O handler.");
response.set_status(Status::SwitchingProtocols);
response.set_raw_header("Connection", "Upgrade");
response.set_raw_header("Upgrade", proto.to_string());
return Some(io_handler);
}
None
}
/// Calls the handler for each matching route until one of the handlers
/// returns success or error, or there are no additional routes to try, in
/// which case a `Forward` with the last forwarding state is returned.
#[inline]
async fn route<'s, 'r: 's>(
&'s self,
request: &'r Request<'s>,
mut data: Data<'r>,
) -> route::Outcome<'r> {
// Go through all matching routes until we fail or succeed or run out of
// routes to try, in which case we forward with the last status.
let mut status = Status::NotFound;
for route in self.router.route(request) {
// Retrieve and set the requests parameters.
info_!("Matched: {}", route);
request.set_route(route);
let name = route.name.as_deref();
let outcome = catch_handle(name, || route.handler.handle(request, data)).await
.unwrap_or(Outcome::Error(Status::InternalServerError));
// Check if the request processing completed (Some) or if the
// request needs to be forwarded. If it does, continue the loop
// (None) to try again.
info_!("{}", outcome.log_display());
match outcome {
o@Outcome::Success(_) | o@Outcome::Error(_) => return o,
Outcome::Forward(forwarded) => (data, status) = forwarded,
}
}
error_!("No matching routes for {}.", request);
Outcome::Forward((data, status))
}
// Invokes the catcher for `status`. Returns the response on success.
//
// Resets the cookie jar delta state to prevent any modifications from
// earlier unsuccessful paths from being reflected in the error response.
//
// On catcher error, the 500 error catcher is attempted. If _that_ errors,
// the (infallible) default 500 error cather is used.
pub(crate) async fn dispatch_error<'r, 's: 'r>(
&'s self,
mut status: Status,
req: &'r Request<'s>
) -> Response<'r> {
// We may wish to relax this in the future.
req.cookies().reset_delta();
// Dispatch to the `status` catcher.
if let Ok(r) = self.invoke_catcher(status, req).await {
return r;
}
// If it fails and it's not a 500, try the 500 catcher.
if status != Status::InternalServerError {
error_!("Catcher failed. Attempting 500 error catcher.");
status = Status::InternalServerError;
if let Ok(r) = self.invoke_catcher(status, req).await {
return r;
}
}
// If it failed again or if it was already a 500, use Rocket's default.
error_!("{} catcher failed. Using Rocket default 500.", status.code);
crate::catcher::default_handler(Status::InternalServerError, req)
}
/// Invokes the handler with `req` for catcher with status `status`.
///
/// In order of preference, invoked handler is:
/// * the user's registered handler for `status`
/// * the user's registered `default` handler
/// * Rocket's default handler for `status`
///
/// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))`
/// if the handler ran to completion but failed. Returns `Ok(None)` if the
/// handler panicked while executing.
async fn invoke_catcher<'s, 'r: 's>(
&'s self,
status: Status,
req: &'r Request<'s>
) -> Result<Response<'r>, Option<Status>> {
if let Some(catcher) = self.router.catch(status, req) {
warn_!("Responding with registered {} catcher.", catcher);
let name = catcher.name.as_deref();
catch_handle(name, || catcher.handler.handle(status, req)).await
.map(|result| result.map_err(Some))
.unwrap_or_else(|| Err(None))
} else {
let code = status.code.blue().bold();
warn_!("No {} catcher registered. Using Rocket default.", code);
Ok(crate::catcher::default_handler(status, req))
}
}
}

View File

@ -0,0 +1,40 @@
use futures::TryFutureExt;
use crate::listener::Listener;
pub trait Bindable: Sized {
type Listener: Listener + 'static;
type Error: std::error::Error + Send + 'static;
async fn bind(self) -> Result<Self::Listener, Self::Error>;
}
impl<L: Listener + 'static> Bindable for L {
type Listener = L;
type Error = std::convert::Infallible;
async fn bind(self) -> Result<Self::Listener, Self::Error> {
Ok(self)
}
}
impl<A: Bindable, B: Bindable> Bindable for either::Either<A, B> {
type Listener = tokio_util::either::Either<A::Listener, B::Listener>;
type Error = either::Either<A::Error, B::Error>;
async fn bind(self) -> Result<Self::Listener, Self::Error> {
match self {
either::Either::Left(a) => a.bind()
.map_ok(tokio_util::either::Either::Left)
.map_err(either::Either::Left)
.await,
either::Either::Right(b) => b.bind()
.map_ok(tokio_util::either::Either::Right)
.map_err(either::Either::Right)
.await,
}
}
}

View File

@ -0,0 +1,58 @@
use std::{io, time::Duration};
use crate::listener::{Listener, Endpoint};
static DURATION: Duration = Duration::from_millis(250);
pub struct Bounced<L> {
listener: L,
}
pub trait BouncedExt: Sized {
fn bounced(self) -> Bounced<Self> {
Bounced { listener: self }
}
}
impl<L> BouncedExt for L { }
fn is_recoverable(e: &io::Error) -> bool {
matches!(e.kind(),
| io::ErrorKind::ConnectionRefused
| io::ErrorKind::ConnectionAborted
| io::ErrorKind::ConnectionReset)
}
impl<L: Listener + Sync> Bounced<L> {
#[inline]
pub async fn accept_next(&self) -> <Self as Listener>::Accept {
loop {
match self.listener.accept().await {
Ok(accept) => return accept,
Err(e) if is_recoverable(&e) => warn!("recoverable connection error: {e}"),
Err(e) => {
warn!("accept error: {e} [retrying in {}ms]", DURATION.as_millis());
tokio::time::sleep(DURATION).await;
}
};
}
}
}
impl<L: Listener + Sync> Listener for Bounced<L> {
type Accept = L::Accept;
type Connection = L::Connection;
async fn accept(&self) -> io::Result<Self::Accept> {
Ok(self.accept_next().await)
}
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
self.listener.connect(accept).await
}
fn socket_addr(&self) -> io::Result<Endpoint> {
self.listener.socket_addr()
}
}

View File

@ -0,0 +1,273 @@
use std::io;
use std::time::Duration;
use std::task::{Poll, Context};
use std::pin::Pin;
use tokio::time::{sleep, Sleep};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use futures::{StreamExt, future::{select, Either, Fuse, Future, FutureExt}};
use pin_project_lite::pin_project;
use crate::{config, Shutdown};
use crate::listener::{Listener, Connection, Certificates, Bounced, Endpoint};
// Rocket wraps all connections in a `CancellableIo` struct, an internal
// structure that gracefully closes I/O when it receives a signal. That signal
// is the `shutdown` future. When the future resolves, `CancellableIo` begins to
// terminate in grace, mercy, and finally force close phases. Since all
// connections are wrapped in `CancellableIo`, this eventually ends all I/O.
//
// At that point, unless a user spawned an infinite, stand-alone task that isn't
// monitoring `Shutdown`, all tasks should resolve. This means that all
// instances of the shared `Arc<Rocket>` are dropped and we can return the owned
// instance of `Rocket`.
//
// Unfortunately, the Hyper `server` future resolves as soon as it has finished
// processing requests without respect for ongoing responses. That is, `server`
// resolves even when there are running tasks that are generating a response.
// So, `server` resolving implies little to nothing about the state of
// connections. As a result, we depend on the timing of grace + mercy + some
// buffer to determine when all connections should be closed, thus all tasks
// should be complete, thus all references to `Arc<Rocket>` should be dropped
// and we can get a unique reference.
pin_project! {
pub struct CancellableListener<F, L> {
pub trigger: F,
#[pin]
pub listener: L,
pub grace: Duration,
pub mercy: Duration,
}
}
pin_project! {
/// I/O that can be cancelled when a future `F` resolves.
#[must_use = "futures do nothing unless polled"]
pub struct CancellableIo<F, I> {
#[pin]
io: Option<I>,
#[pin]
trigger: Fuse<F>,
state: State,
grace: Duration,
mercy: Duration,
}
}
enum State {
/// I/O has not been cancelled. Proceed as normal.
Active,
/// I/O has been cancelled. See if we can finish before the timer expires.
Grace(Pin<Box<Sleep>>),
/// Grace period elapsed. Shutdown the connection, waiting for the timer
/// until we force close.
Mercy(Pin<Box<Sleep>>),
}
pub trait CancellableExt: Sized {
fn cancellable(
self,
trigger: Shutdown,
config: &config::Shutdown
) -> CancellableListener<Shutdown, Self> {
if let Some(mut stream) = config.signal_stream() {
let trigger = trigger.clone();
tokio::spawn(async move {
while let Some(sig) = stream.next().await {
if trigger.0.tripped() {
warn!("Received {}. Shutdown already in progress.", sig);
} else {
warn!("Received {}. Requesting shutdown.", sig);
}
trigger.0.trip();
}
});
};
CancellableListener {
trigger,
listener: self,
grace: config.grace(),
mercy: config.mercy(),
}
}
}
impl<L: Listener> CancellableExt for L { }
fn time_out() -> io::Error {
io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out")
}
fn gone() -> io::Error {
io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated")
}
impl<L, F> CancellableListener<F, Bounced<L>>
where L: Listener + Sync,
F: Future + Unpin + Clone + Send + Sync + 'static
{
pub async fn accept_next(&self) -> Option<<Self as Listener>::Accept> {
let next = std::pin::pin!(self.listener.accept_next());
match select(next, self.trigger.clone()).await {
Either::Left((next, _)) => Some(next),
Either::Right(_) => None,
}
}
}
impl<L, F> CancellableListener<F, L>
where L: Listener + Sync,
F: Future + Clone + Send + Sync + 'static
{
fn io<C>(&self, conn: C) -> CancellableIo<F, C> {
CancellableIo {
io: Some(conn),
trigger: self.trigger.clone().fuse(),
state: State::Active,
grace: self.grace,
mercy: self.mercy,
}
}
}
impl<L, F> Listener for CancellableListener<F, L>
where L: Listener + Sync,
F: Future + Clone + Send + Sync + Unpin + 'static
{
type Accept = L::Accept;
type Connection = CancellableIo<F, L::Connection>;
async fn accept(&self) -> io::Result<Self::Accept> {
let accept = std::pin::pin!(self.listener.accept());
match select(accept, self.trigger.clone()).await {
Either::Left((result, _)) => result,
Either::Right(_) => Err(gone()),
}
}
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
let conn = std::pin::pin!(self.listener.connect(accept));
match select(conn, self.trigger.clone()).await {
Either::Left((conn, _)) => Ok(self.io(conn?)),
Either::Right(_) => Err(gone()),
}
}
fn socket_addr(&self) -> io::Result<Endpoint> {
self.listener.socket_addr()
}
}
impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
fn inner(&self) -> Option<&I> {
self.io.as_ref()
}
/// Run `do_io` while connection processing should continue.
fn poll_trigger_then<T>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
) -> Poll<io::Result<T>> {
let mut me = self.as_mut().project();
let io = match me.io.as_pin_mut() {
Some(io) => io,
None => return Poll::Ready(Err(gone())),
};
loop {
match me.state {
State::Active => {
if me.trigger.as_mut().poll(cx).is_ready() {
*me.state = State::Grace(Box::pin(sleep(*me.grace)));
} else {
return do_io(io, cx);
}
}
State::Grace(timer) => {
if timer.as_mut().poll(cx).is_ready() {
*me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
} else {
return do_io(io, cx);
}
}
State::Mercy(timer) => {
if timer.as_mut().poll(cx).is_ready() {
self.project().io.set(None);
return Poll::Ready(Err(time_out()));
} else {
let result = futures::ready!(io.poll_shutdown(cx));
self.project().io.set(None);
return match result {
Err(e) => Poll::Ready(Err(e)),
Ok(()) => Poll::Ready(Err(gone()))
};
}
},
}
}
}
}
impl<F: Future, I: AsyncRead + AsyncWrite> AsyncRead for CancellableIo<F, I> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf))
}
}
impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf))
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx))
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx))
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs))
}
fn is_write_vectored(&self) -> bool {
self.inner().map(|io| io.is_write_vectored()).unwrap_or(false)
}
}
impl<F: Future, C: Connection> Connection for CancellableIo<F, C>
where F: Unpin + Send + 'static
{
fn peer_address(&self) -> io::Result<Endpoint> {
self.inner()
.ok_or_else(|| gone())
.and_then(|io| io.peer_address())
}
fn peer_certificates(&self) -> Option<Certificates<'_>> {
self.inner().and_then(|io| io.peer_certificates())
}
}

View File

@ -0,0 +1,93 @@
use std::io;
use std::borrow::Cow;
use tokio_util::either::Either;
use tokio::io::{AsyncRead, AsyncWrite};
use super::Endpoint;
/// A collection of raw certificate data.
#[derive(Clone)]
pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>);
pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
fn peer_address(&self) -> io::Result<Endpoint>;
/// DER-encoded X.509 certificate chain presented by the client, if any.
///
/// The certificate order must be as it appears in the TLS protocol: the
/// first certificate relates to the peer, the second certifies the first,
/// the third certifies the second, and so on.
///
/// Defaults to an empty vector to indicate that no certificates were
/// presented.
fn peer_certificates(&self) -> Option<Certificates<'_>> { None }
}
impl<A: Connection, B: Connection> Connection for Either<A, B> {
fn peer_address(&self) -> io::Result<Endpoint> {
match self {
Either::Left(c) => c.peer_address(),
Either::Right(c) => c.peer_address(),
}
}
fn peer_certificates(&self) -> Option<Certificates<'_>> {
match self {
Either::Left(c) => c.peer_certificates(),
Either::Right(c) => c.peer_certificates(),
}
}
}
impl Certificates<'_> {
pub fn into_owned(self) -> Certificates<'static> {
let cow = self.0.into_iter()
.map(|der| der.clone().into_owned())
.collect::<Vec<_>>()
.into();
Certificates(cow)
}
}
#[cfg(feature = "mtls")]
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
mod der {
use super::*;
pub use crate::mtls::CertificateDer;
impl<'r> Certificates<'r> {
pub(crate) fn inner(&self) -> &[CertificateDer<'r>] {
&self.0
}
}
impl<'r> From<&'r [CertificateDer<'r>]> for Certificates<'r> {
fn from(value: &'r [CertificateDer<'r>]) -> Self {
Certificates(value.into())
}
}
impl From<Vec<CertificateDer<'static>>> for Certificates<'static> {
fn from(value: Vec<CertificateDer<'static>>) -> Self {
Certificates(value.into())
}
}
}
#[cfg(not(feature = "mtls"))]
mod der {
use std::marker::PhantomData;
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
#[derive(Clone)]
pub struct CertificateDer<'r>(PhantomData<&'r [u8]>);
impl CertificateDer<'_> {
pub fn into_owned(self) -> CertificateDer<'static> {
CertificateDer(PhantomData)
}
}
}

View File

@ -0,0 +1,61 @@
use either::Either;
use crate::listener::{Bindable, Endpoint};
use crate::error::{Error, ErrorKind};
#[derive(serde::Deserialize)]
pub struct DefaultListener {
#[serde(default)]
pub address: Endpoint,
pub port: Option<u16>,
pub reuse: Option<bool>,
#[cfg(feature = "tls")]
pub tls: Option<crate::tls::TlsConfig>,
}
#[cfg(not(unix))] type BaseBindable = Either<std::net::SocketAddr, std::net::SocketAddr>;
#[cfg(unix)] type BaseBindable = Either<std::net::SocketAddr, super::unix::UdsConfig>;
#[cfg(not(feature = "tls"))] type TlsBindable<T> = Either<T, T>;
#[cfg(feature = "tls")] type TlsBindable<T> = Either<super::tls::TlsBindable<T>, T>;
impl DefaultListener {
pub(crate) fn base_bindable(&self) -> Result<BaseBindable, crate::Error> {
match &self.address {
Endpoint::Tcp(mut address) => {
self.port.map(|port| address.set_port(port));
Ok(BaseBindable::Left(address))
},
#[cfg(unix)]
Endpoint::Unix(path) => {
let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, };
Ok(BaseBindable::Right(uds))
},
#[cfg(not(unix))]
Endpoint::Unix(_) => {
let msg = "Unix domain sockets unavailable on non-unix platforms.";
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(boxed)))
},
other => {
let msg = format!("unsupported default listener address: {other}");
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(boxed)))
}
}
}
pub(crate) fn tls_bindable<T>(&self, inner: T) -> TlsBindable<T> {
#[cfg(feature = "tls")]
if let Some(tls) = self.tls.clone() {
return TlsBindable::Left(super::tls::TlsBindable { inner, tls });
}
TlsBindable::Right(inner)
}
pub fn bindable(&self) -> Result<impl Bindable, crate::Error> {
self.base_bindable()
.map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b)))
}
}

View File

@ -0,0 +1,281 @@
use std::fmt;
use std::path::{Path, PathBuf};
use std::any::Any;
use std::net::{SocketAddr as TcpAddr, Ipv4Addr, AddrParseError};
use std::str::FromStr;
use std::sync::Arc;
use serde::de;
use crate::http::uncased::AsUncased;
pub trait EndpointAddr: fmt::Display + fmt::Debug + Sync + Send + Any { }
impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> EndpointAddr for T {}
#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>;
#[cfg(feature = "tls")] type TlsInfo = Option<crate::tls::TlsConfig>;
/// # Conversions
///
/// * [`&str`] - parse with [`FromStr`]
/// * [`tokio::net::unix::SocketAddr`] - must be path: [`ListenerAddr::Unix`]
/// * [`std::net::SocketAddr`] - infallibly as [ListenerAddr::Tcp]
/// * [`PathBuf`] - infallibly as [`ListenerAddr::Unix`]
// TODO: Rename to something better. `Endpoint`?
#[derive(Debug)]
pub enum Endpoint {
Tcp(TcpAddr),
Unix(PathBuf),
Tls(Arc<Endpoint>, TlsInfo),
Custom(Arc<dyn EndpointAddr>),
}
impl Endpoint {
pub fn new<T: EndpointAddr>(value: T) -> Endpoint {
Endpoint::Custom(Arc::new(value))
}
pub fn tcp(&self) -> Option<TcpAddr> {
match self {
Endpoint::Tcp(addr) => Some(*addr),
_ => None,
}
}
pub fn unix(&self) -> Option<&Path> {
match self {
Endpoint::Unix(addr) => Some(addr),
_ => None,
}
}
pub fn tls(&self) -> Option<&Endpoint> {
match self {
Endpoint::Tls(addr, _) => Some(addr),
_ => None,
}
}
#[cfg(feature = "tls")]
pub fn tls_config(&self) -> Option<&crate::tls::TlsConfig> {
match self {
Endpoint::Tls(_, Some(ref config)) => Some(config),
_ => None,
}
}
#[cfg(feature = "mtls")]
pub fn mtls_config(&self) -> Option<&crate::mtls::MtlsConfig> {
match self {
Endpoint::Tls(_, Some(config)) => config.mutual(),
_ => None,
}
}
pub fn downcast<T: 'static>(&self) -> Option<&T> {
match self {
Endpoint::Tcp(addr) => (&*addr as &dyn Any).downcast_ref(),
Endpoint::Unix(addr) => (&*addr as &dyn Any).downcast_ref(),
Endpoint::Custom(addr) => (&*addr as &dyn Any).downcast_ref(),
Endpoint::Tls(inner, ..) => inner.downcast(),
}
}
pub fn is_tcp(&self) -> bool {
self.tcp().is_some()
}
pub fn is_unix(&self) -> bool {
self.unix().is_some()
}
pub fn is_tls(&self) -> bool {
self.tls().is_some()
}
#[cfg(feature = "tls")]
pub fn with_tls(self, config: crate::tls::TlsConfig) -> Endpoint {
if self.is_tls() {
return self;
}
Self::Tls(Arc::new(self), Some(config))
}
pub fn assume_tls(self) -> Endpoint {
if self.is_tls() {
return self;
}
Self::Tls(Arc::new(self), None)
}
}
impl fmt::Display for Endpoint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use Endpoint::*;
match self {
Tcp(addr) => write!(f, "http://{addr}"),
Unix(addr) => write!(f, "unix:{}", addr.display()),
Custom(inner) => inner.fmt(f),
Tls(inner, c) => match (&**inner, c.as_ref()) {
#[cfg(feature = "mtls")]
(Tcp(i), Some(c)) if c.mutual().is_some() => write!(f, "https://{i} (TLS + MTLS)"),
(Tcp(i), _) => write!(f, "https://{i} (TLS)"),
#[cfg(feature = "mtls")]
(i, Some(c)) if c.mutual().is_some() => write!(f, "{i} (TLS + MTLS)"),
(inner, _) => write!(f, "{inner} (TLS)"),
},
}
}
}
impl From<std::net::SocketAddr> for Endpoint {
fn from(value: std::net::SocketAddr) -> Self {
Self::Tcp(value)
}
}
impl From<std::net::SocketAddrV4> for Endpoint {
fn from(value: std::net::SocketAddrV4) -> Self {
Self::Tcp(value.into())
}
}
impl From<std::net::SocketAddrV6> for Endpoint {
fn from(value: std::net::SocketAddrV6) -> Self {
Self::Tcp(value.into())
}
}
impl From<PathBuf> for Endpoint {
fn from(value: PathBuf) -> Self {
Self::Unix(value)
}
}
#[cfg(unix)]
impl TryFrom<tokio::net::unix::SocketAddr> for Endpoint {
type Error = std::io::Error;
fn try_from(v: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
v.as_pathname()
.ok_or_else(|| std::io::Error::other("unix socket is not path"))
.map(|path| Endpoint::Unix(path.to_path_buf()))
}
}
impl TryFrom<&str> for Endpoint {
type Error = AddrParseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
value.parse()
}
}
impl Default for Endpoint {
fn default() -> Self {
Endpoint::Tcp(TcpAddr::new(Ipv4Addr::LOCALHOST.into(), 8000))
}
}
/// Parses an address into a `ListenerAddr`.
///
/// The syntax is:
///
/// ```text
/// listener_addr = 'tcp' ':' tcp_addr | 'unix' ':' unix_addr | tcp_addr
/// tcp_addr := IP_ADDR | SOCKET_ADDR
/// unix_addr := PATH
///
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
/// ```
///
/// If `IP_ADDR` is specified, the port defaults to `8000`.
impl FromStr for Endpoint {
type Err = AddrParseError;
fn from_str(string: &str) -> Result<Self, Self::Err> {
fn parse_tcp(string: &str, def_port: u16) -> Result<TcpAddr, AddrParseError> {
string.parse().or_else(|_| string.parse().map(|ip| TcpAddr::new(ip, def_port)))
}
if let Some((proto, string)) = string.split_once(':') {
if proto.trim().as_uncased() == "tcp" {
return parse_tcp(string.trim(), 8000).map(Self::Tcp);
} else if proto.trim().as_uncased() == "unix" {
return Ok(Self::Unix(PathBuf::from(string.trim())));
}
}
parse_tcp(string.trim(), 8000).map(Self::Tcp)
}
}
impl<'de> de::Deserialize<'de> for Endpoint {
fn deserialize<D: de::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
type Value = Endpoint;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("TCP or Unix address")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
v.parse::<Endpoint>().map_err(|e| E::custom(e.to_string()))
}
}
de.deserialize_any(Visitor)
}
}
impl Eq for Endpoint { }
impl PartialEq for Endpoint {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Tcp(l0), Self::Tcp(r0)) => l0 == r0,
(Self::Unix(l0), Self::Unix(r0)) => l0 == r0,
(Self::Tls(l0, _), Self::Tls(r0, _)) => l0 == r0,
(Self::Custom(l0), Self::Custom(r0)) => l0.to_string() == r0.to_string(),
_ => false,
}
}
}
impl PartialEq<std::net::SocketAddr> for Endpoint {
fn eq(&self, other: &std::net::SocketAddr) -> bool {
self.tcp() == Some(*other)
}
}
impl PartialEq<std::net::SocketAddrV4> for Endpoint {
fn eq(&self, other: &std::net::SocketAddrV4) -> bool {
self.tcp() == Some((*other).into())
}
}
impl PartialEq<std::net::SocketAddrV6> for Endpoint {
fn eq(&self, other: &std::net::SocketAddrV6) -> bool {
self.tcp() == Some((*other).into())
}
}
impl PartialEq<PathBuf> for Endpoint {
fn eq(&self, other: &PathBuf) -> bool {
self.unix() == Some(other.as_path())
}
}
impl PartialEq<Path> for Endpoint {
fn eq(&self, other: &Path) -> bool {
self.unix() == Some(other)
}
}

View File

@ -0,0 +1,65 @@
use std::io;
use futures::TryFutureExt;
use tokio_util::either::Either;
use crate::listener::{Connection, Endpoint};
pub trait Listener: Send + Sync {
type Accept: Send;
type Connection: Connection;
async fn accept(&self) -> io::Result<Self::Accept>;
#[crate::async_bound(Send)]
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection>;
fn socket_addr(&self) -> io::Result<Endpoint>;
}
impl<L: Listener> Listener for &L {
type Accept = L::Accept;
type Connection = L::Connection;
async fn accept(&self) -> io::Result<Self::Accept> {
<L as Listener>::accept(self).await
}
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
<L as Listener>::connect(self, accept).await
}
fn socket_addr(&self) -> io::Result<Endpoint> {
<L as Listener>::socket_addr(self)
}
}
impl<A: Listener, B: Listener> Listener for Either<A, B> {
type Accept = Either<A::Accept, B::Accept>;
type Connection = Either<A::Connection, B::Connection>;
async fn accept(&self) -> io::Result<Self::Accept> {
match self {
Either::Left(l) => l.accept().map_ok(Either::Left).await,
Either::Right(l) => l.accept().map_ok(Either::Right).await,
}
}
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
match (self, accept) {
(Either::Left(l), Either::Left(a)) => l.connect(a).map_ok(Either::Left).await,
(Either::Right(l), Either::Right(a)) => l.connect(a).map_ok(Either::Right).await,
_ => unreachable!()
}
}
fn socket_addr(&self) -> io::Result<Endpoint> {
match self {
Either::Left(l) => l.socket_addr(),
Either::Right(l) => l.socket_addr(),
}
}
}

View File

@ -0,0 +1,24 @@
mod cancellable;
mod bounced;
mod listener;
mod endpoint;
mod connection;
mod bindable;
mod default;
#[cfg(unix)]
#[cfg_attr(nightly, doc(cfg(unix)))]
pub mod unix;
#[cfg(feature = "tls")]
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
pub mod tls;
pub mod tcp;
pub use endpoint::*;
pub use listener::*;
pub use connection::*;
pub use bindable::*;
pub use default::*;
pub(crate) use cancellable::*;
pub(crate) use bounced::*;

View File

@ -0,0 +1,43 @@
use std::io;
#[doc(inline)]
pub use tokio::net::{TcpListener, TcpStream};
use crate::listener::{Listener, Bindable, Connection, Endpoint};
impl Bindable for std::net::SocketAddr {
type Listener = TcpListener;
type Error = io::Error;
async fn bind(self) -> Result<Self::Listener, Self::Error> {
TcpListener::bind(self).await
}
}
impl Listener for TcpListener {
type Accept = Self::Connection;
type Connection = TcpStream;
async fn accept(&self) -> io::Result<Self::Accept> {
let conn = self.accept().await?.0;
let _ = conn.set_nodelay(true);
let _ = conn.set_linger(None);
Ok(conn)
}
async fn connect(&self, conn: Self::Connection) -> io::Result<Self::Connection> {
Ok(conn)
}
fn socket_addr(&self) -> io::Result<Endpoint> {
self.local_addr().map(Endpoint::Tcp)
}
}
impl Connection for TcpStream {
fn peer_address(&self) -> io::Result<Endpoint> {
self.peer_addr().map(Endpoint::Tcp)
}
}

View File

@ -0,0 +1,116 @@
use std::io;
use std::sync::Arc;
use serde::Deserialize;
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
use tokio_rustls::TlsAcceptor;
use crate::tls::{TlsConfig, Error};
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint};
#[doc(inline)]
pub use tokio_rustls::server::TlsStream;
/// A TLS listener over some listener interface L.
pub struct TlsListener<L> {
listener: L,
acceptor: TlsAcceptor,
config: TlsConfig,
}
#[derive(Clone, Deserialize)]
pub struct TlsBindable<I> {
#[serde(flatten)]
pub inner: I,
pub tls: TlsConfig,
}
impl TlsConfig {
pub(crate) fn acceptor(&self) -> Result<tokio_rustls::TlsAcceptor, Error> {
let provider = rustls::crypto::CryptoProvider {
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
..rustls::crypto::ring::default_provider()
};
#[cfg(feature = "mtls")]
let verifier = match self.mutual {
Some(ref mtls) => {
let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?;
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs));
match mtls.mandatory {
true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?,
}
},
None => WebPkiClientVerifier::no_client_auth(),
};
#[cfg(not(feature = "mtls"))]
let verifier = WebPkiClientVerifier::no_client_auth();
let key = load_key(&mut self.key_reader()?)?;
let cert_chain = load_cert_chain(&mut self.certs_reader()?)?;
let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)?;
tls_config.ignore_client_order = self.prefer_server_cipher_order;
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
if cfg!(feature = "http2") {
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
}
Ok(TlsAcceptor::from(Arc::new(tls_config)))
}
}
impl<I: Bindable> Bindable for TlsBindable<I> {
type Listener = TlsListener<I::Listener>;
type Error = Error;
async fn bind(self) -> Result<Self::Listener, Self::Error> {
Ok(TlsListener {
acceptor: self.tls.acceptor()?,
listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?,
config: self.tls,
})
}
}
impl<L: Listener + Sync> Listener for TlsListener<L>
where L::Connection: Unpin
{
type Accept = L::Accept;
type Connection = TlsStream<L::Connection>;
async fn accept(&self) -> io::Result<Self::Accept> {
self.listener.accept().await
}
async fn connect(&self, accept: L::Accept) -> io::Result<Self::Connection> {
let conn = self.listener.connect(accept).await?;
self.acceptor.accept(conn).await
}
fn socket_addr(&self) -> io::Result<Endpoint> {
Ok(self.listener.socket_addr()?.with_tls(self.config.clone()))
}
}
impl<C: Connection + Unpin> Connection for TlsStream<C> {
fn peer_address(&self) -> io::Result<Endpoint> {
Ok(self.get_ref().0.peer_address()?.assume_tls())
}
#[cfg(feature = "mtls")]
fn peer_certificates(&self) -> Option<Certificates<'_>> {
let cert_chain = self.get_ref().1.peer_certificates()?;
Some(Certificates::from(cert_chain))
}
}

View File

@ -0,0 +1,107 @@
use std::io;
use std::path::PathBuf;
use tokio::time::{sleep, Duration};
use crate::fs::NamedFile;
use crate::listener::{Listener, Bindable, Connection, Endpoint};
use crate::util::unix;
pub use tokio::net::UnixStream;
#[derive(Debug, Clone)]
pub struct UdsConfig {
/// Socket address.
pub path: PathBuf,
/// Recreate a socket that already exists.
pub reuse: Option<bool>,
}
pub struct UdsListener {
path: PathBuf,
lock: Option<NamedFile>,
listener: tokio::net::UnixListener,
}
impl Bindable for UdsConfig {
type Listener = UdsListener;
type Error = io::Error;
async fn bind(self) -> Result<Self::Listener, Self::Error> {
let lock = if self.reuse.unwrap_or(true) {
let lock_ext = match self.path.extension().and_then(|s| s.to_str()) {
Some(ext) if !ext.is_empty() => format!("{}.lock", ext),
_ => "lock".to_string()
};
let mut opts = tokio::fs::File::options();
opts.create(true).write(true);
let lock_path = self.path.with_extension(lock_ext);
let lock_file = NamedFile::open_with(lock_path, &opts).await?;
unix::lock_exlusive_nonblocking(lock_file.file())?;
if self.path.exists() {
tokio::fs::remove_file(&self.path).await?;
}
Some(lock_file)
} else {
None
};
// Sometimes, we get `AddrInUse`, even though we've tried deleting the
// socket. If all is well, eventually the socket will _really_ be gone,
// and this will succeed. So let's try a few times.
let mut retries = 5;
let listener = loop {
match tokio::net::UnixListener::bind(&self.path) {
Ok(listener) => break listener,
Err(e) if self.path.exists() && lock.is_none() => return Err(e),
Err(_) if retries > 0 => {
retries -= 1;
sleep(Duration::from_millis(100)).await;
},
Err(e) => return Err(e),
}
};
Ok(UdsListener { lock, listener, path: self.path, })
}
}
impl Listener for UdsListener {
type Accept = UnixStream;
type Connection = Self::Accept;
async fn accept(&self) -> io::Result<Self::Accept> {
Ok(self.listener.accept().await?.0)
}
async fn connect(&self, accept:Self::Accept) -> io::Result<Self::Connection> {
Ok(accept)
}
fn socket_addr(&self) -> io::Result<Endpoint> {
self.listener.local_addr()?.try_into()
}
}
impl Connection for UnixStream {
fn peer_address(&self) -> io::Result<Endpoint> {
self.local_addr()?.try_into()
}
}
impl Drop for UdsListener {
fn drop(&mut self) {
if let Some(lock) = &self.lock {
let _ = std::fs::remove_file(&self.path);
let _ = std::fs::remove_file(lock.path());
let _ = unix::unlock_nonblocking(lock.file());
} else {
let _ = std::fs::remove_file(&self.path);
}
}
}

View File

@ -4,7 +4,8 @@ use parking_lot::RwLock;
use crate::{Rocket, Phase, Orbit, Ignite, Error}; use crate::{Rocket, Phase, Orbit, Ignite, Error};
use crate::local::asynchronous::{LocalRequest, LocalResponse}; use crate::local::asynchronous::{LocalRequest, LocalResponse};
use crate::http::{Method, uri::Origin, private::cookie}; use crate::http::{Method, uri::Origin};
use crate::listener::Endpoint;
/// An `async` client to construct and dispatch local requests. /// An `async` client to construct and dispatch local requests.
/// ///
@ -55,9 +56,15 @@ pub struct Client {
impl Client { impl Client {
pub(crate) async fn _new<P: Phase>( pub(crate) async fn _new<P: Phase>(
rocket: Rocket<P>, rocket: Rocket<P>,
tracked: bool tracked: bool,
secure: bool,
) -> Result<Client, Error> { ) -> Result<Client, Error> {
let rocket = rocket.local_launch().await?; let mut listener = Endpoint::new("local client");
if secure {
listener = listener.assume_tls();
}
let rocket = rocket.local_launch(listener).await?;
let cookies = RwLock::new(cookie::CookieJar::new()); let cookies = RwLock::new(cookie::CookieJar::new());
Ok(Client { rocket, cookies, tracked }) Ok(Client { rocket, cookies, tracked })
} }

View File

@ -23,7 +23,7 @@ use super::{Client, LocalResponse};
/// let client = Client::tracked(rocket::build()).await.expect("valid rocket"); /// let client = Client::tracked(rocket::build()).await.expect("valid rocket");
/// let req = client.post("/") /// let req = client.post("/")
/// .header(ContentType::JSON) /// .header(ContentType::JSON)
/// .remote("127.0.0.1:8000".parse().unwrap()) /// .remote("127.0.0.1:8000")
/// .cookie(("name", "value")) /// .cookie(("name", "value"))
/// .body(r#"{ "value": 42 }"#); /// .body(r#"{ "value": 42 }"#);
/// ///
@ -86,14 +86,14 @@ impl<'c> LocalRequest<'c> {
if self.inner().uri() == invalid { if self.inner().uri() == invalid {
error!("invalid request URI: {:?}", invalid.path()); error!("invalid request URI: {:?}", invalid.path());
return LocalResponse::new(self.request, move |req| { return LocalResponse::new(self.request, move |req| {
rocket.handle_error(Status::BadRequest, req) rocket.dispatch_error(Status::BadRequest, req)
}).await }).await
} }
} }
// Actually dispatch the request. // Actually dispatch the request.
let mut data = Data::local(self.data); let mut data = Data::local(self.data);
let token = rocket.preprocess_request(&mut self.request, &mut data).await; let token = rocket.preprocess(&mut self.request, &mut data).await;
let response = LocalResponse::new(self.request, move |req| { let response = LocalResponse::new(self.request, move |req| {
rocket.dispatch(token, req, data) rocket.dispatch(token, req, data)
}).await; }).await;

View File

@ -53,9 +53,14 @@ use crate::{Request, Response};
/// ///
/// For more, see [the top-level documentation](../index.html#localresponse). /// For more, see [the top-level documentation](../index.html#localresponse).
pub struct LocalResponse<'c> { pub struct LocalResponse<'c> {
_request: Box<Request<'c>>, // XXX: SAFETY: This (dependent) field must come first due to drop order!
response: Response<'c>, response: Response<'c>,
cookies: CookieJar<'c>, cookies: CookieJar<'c>,
_request: Box<Request<'c>>,
}
impl Drop for LocalResponse<'_> {
fn drop(&mut self) { }
} }
impl<'c> LocalResponse<'c> { impl<'c> LocalResponse<'c> {
@ -64,7 +69,8 @@ impl<'c> LocalResponse<'c> {
O: Future<Output = Response<'c>> + Send O: Future<Output = Response<'c>> + Send
{ {
// `LocalResponse` is a self-referential structure. In particular, // `LocalResponse` is a self-referential structure. In particular,
// `inner` can refer to `_request` and its contents. As such, we must // `response` and `cookies` can refer to `_request` and its contents. As
// such, we must
// 1) Ensure `Request` has a stable address. // 1) Ensure `Request` has a stable address.
// //
// This is done by `Box`ing the `Request`, using only the stable // This is done by `Box`ing the `Request`, using only the stable
@ -97,7 +103,7 @@ impl<'c> LocalResponse<'c> {
cookies.add_original(cookie.into_owned()); cookies.add_original(cookie.into_owned());
} }
LocalResponse { cookies, _request: boxed_req, response, } LocalResponse { _request: boxed_req, cookies, response, }
} }
} }
} }

View File

@ -30,7 +30,7 @@ pub struct Client {
} }
impl Client { impl Client {
fn _new<P: Phase>(rocket: Rocket<P>, tracked: bool) -> Result<Client, Error> { fn _new<P: Phase>(rocket: Rocket<P>, tracked: bool, secure: bool) -> Result<Client, Error> {
let runtime = tokio::runtime::Builder::new_multi_thread() let runtime = tokio::runtime::Builder::new_multi_thread()
.thread_name("rocket-local-client-worker-thread") .thread_name("rocket-local-client-worker-thread")
.worker_threads(1) .worker_threads(1)
@ -39,7 +39,7 @@ impl Client {
.expect("create tokio runtime"); .expect("create tokio runtime");
// Initialize the Rocket instance // Initialize the Rocket instance
let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked))?); let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked, secure))?);
Ok(Self { inner, runtime: RefCell::new(runtime) }) Ok(Self { inner, runtime: RefCell::new(runtime) })
} }
@ -73,7 +73,7 @@ impl Client {
#[inline(always)] #[inline(always)]
pub(crate) fn _with_raw_cookies<F, T>(&self, f: F) -> T pub(crate) fn _with_raw_cookies<F, T>(&self, f: F) -> T
where F: FnOnce(&crate::http::private::cookie::CookieJar) -> T where F: FnOnce(&cookie::CookieJar) -> T
{ {
self.inner()._with_raw_cookies(f) self.inner()._with_raw_cookies(f)
} }

View File

@ -21,7 +21,7 @@ use super::{Client, LocalResponse};
/// let client = Client::tracked(rocket::build()).expect("valid rocket"); /// let client = Client::tracked(rocket::build()).expect("valid rocket");
/// let req = client.post("/") /// let req = client.post("/")
/// .header(ContentType::JSON) /// .header(ContentType::JSON)
/// .remote("127.0.0.1:8000".parse().unwrap()) /// .remote("127.0.0.1:8000")
/// .cookie(("name", "value")) /// .cookie(("name", "value"))
/// .body(r#"{ "value": 42 }"#); /// .body(r#"{ "value": 42 }"#);
/// ///

View File

@ -68,7 +68,12 @@ macro_rules! pub_client_impl {
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub $($prefix)? fn tracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> { pub $($prefix)? fn tracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
Self::_new(rocket, true) $(.$suffix)? Self::_new(rocket, true, false) $(.$suffix)?
}
#[inline(always)]
pub $($prefix)? fn tracked_secure<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
Self::_new(rocket, true, true) $(.$suffix)?
} }
/// Construct a new `Client` from an instance of `Rocket` _without_ /// Construct a new `Client` from an instance of `Rocket` _without_
@ -92,7 +97,11 @@ macro_rules! pub_client_impl {
/// let client = Client::untracked(rocket); /// let client = Client::untracked(rocket);
/// ``` /// ```
pub $($prefix)? fn untracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> { pub $($prefix)? fn untracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
Self::_new(rocket, false) $(.$suffix)? Self::_new(rocket, false, false) $(.$suffix)?
}
pub $($prefix)? fn untracked_secure<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
Self::_new(rocket, false, true) $(.$suffix)?
} }
/// Terminates `Client` by initiating a graceful shutdown via /// Terminates `Client` by initiating a graceful shutdown via
@ -135,15 +144,6 @@ macro_rules! pub_client_impl {
Self::tracked(rocket.configure(figment)) $(.$suffix)? Self::tracked(rocket.configure(figment)) $(.$suffix)?
} }
/// Deprecated alias to [`Client::tracked()`].
#[deprecated(
since = "0.6.0-dev",
note = "choose between `Client::untracked()` and `Client::tracked()`"
)]
pub $($prefix)? fn new<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
Self::tracked(rocket) $(.$suffix)?
}
/// Returns a reference to the `Rocket` this client is creating requests /// Returns a reference to the `Rocket` this client is creating requests
/// for. /// for.
/// ///

View File

@ -97,24 +97,40 @@ macro_rules! pub_request_impl {
self._request_mut().add_header(header.into()); self._request_mut().add_header(header.into());
} }
/// Set the remote address of this request. /// Set the remote address of this request to `address`.
///
/// `address` may be any type that [can be converted into a `ListenerAddr`].
/// If `address` fails to convert, the remote is left unchanged.
///
/// [can be converted into a `ListenerAddr`]: crate::listener::ListenerAddr#conversions
/// ///
/// # Examples /// # Examples
/// ///
/// Set the remote address to "8.8.8.8:80": /// Set the remote address to "8.8.8.8:80":
/// ///
/// ```rust /// ```rust
/// use std::net::{SocketAddrV4, Ipv4Addr};
///
#[doc = $import] #[doc = $import]
/// ///
/// # Client::_test(|_, request, _| { /// # Client::_test(|_, request, _| {
/// let request: LocalRequest = request; /// let request: LocalRequest = request;
/// let address = "8.8.8.8:80".parse().unwrap(); /// let req = request.remote("8.8.8.8:80");
/// let req = request.remote(address); ///
/// let addr = SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8).into(), 80);
/// assert_eq!(req.inner().remote().unwrap(), &addr);
/// # }); /// # });
/// ``` /// ```
#[inline] #[inline]
pub fn remote(mut self, address: std::net::SocketAddr) -> Self { pub fn remote<T>(mut self, endpoint: T) -> Self
self.set_remote(address); where T: TryInto<crate::listener::Endpoint>
{
if let Ok(endpoint) = endpoint.try_into() {
self.set_remote(endpoint);
} else {
warn!("remote failed to convert");
}
self self
} }
@ -228,11 +244,13 @@ macro_rules! pub_request_impl {
#[cfg(feature = "mtls")] #[cfg(feature = "mtls")]
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub fn identity<C: std::io::Read>(mut self, reader: C) -> Self { pub fn identity<C: std::io::Read>(mut self, reader: C) -> Self {
use crate::http::{tls::util::load_cert_chain, private::Certificates}; use std::sync::Arc;
use crate::tls::util::load_cert_chain;
use crate::listener::Certificates;
let mut reader = std::io::BufReader::new(reader); let mut reader = std::io::BufReader::new(reader);
let certs = load_cert_chain(&mut reader).map(Certificates::from); let certs = load_cert_chain(&mut reader).map(Certificates::from);
self._request_mut().connection.client_certificates = certs.ok(); self._request_mut().connection.peer_certs = certs.ok().map(Arc::new);
self self
} }

View File

@ -1,25 +0,0 @@
//! Support for mutual TLS client certificates.
//!
//! For details on how to configure mutual TLS, see
//! [`MutualTls`](crate::config::MutualTls) and the [TLS
//! guide](https://rocket.rs/master/guide/configuration/#tls). See
//! [`Certificate`] for a request guard that validated, verifies, and retrieves
//! client certificates.
#[doc(inline)]
pub use crate::http::tls::mtls::*;
use crate::request::{Request, FromRequest, Outcome};
use crate::outcome::{try_outcome, IntoOutcome};
use crate::http::Status;
#[crate::async_trait]
impl<'r> FromRequest<'r> for Certificate<'r> {
type Error = Error;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let certs = req.connection.client_certificates.as_ref().or_forward(Status::Unauthorized);
let data = try_outcome!(try_outcome!(certs).chain_data().or_forward(Status::Unauthorized));
Certificate::parse(data).or_error(Status::Unauthorized)
}
}

View File

@ -1,51 +1,8 @@
pub mod oid {
//! Lower-level OID types re-exported from
//! [`oid_registry`](https://docs.rs/oid-registry/0.4) and
//! [`der-parser`](https://docs.rs/der-parser/7).
pub use x509_parser::oid_registry::*;
pub use x509_parser::objects::*;
}
pub mod bigint {
//! Signed and unsigned big integer types re-exported from
//! [`num_bigint`](https://docs.rs/num-bigint/0.4).
pub use x509_parser::der_parser::num_bigint::*;
}
pub mod x509 {
//! Lower-level X.509 types re-exported from
//! [`x509_parser`](https://docs.rs/x509-parser/0.13).
//!
//! Lack of documentation is directly inherited from the source crate.
//! Prefer to use Rocket's wrappers when possible.
pub use x509_parser::certificate::*;
pub use x509_parser::cri_attributes::*;
pub use x509_parser::error::*;
pub use x509_parser::extensions::*;
pub use x509_parser::revocation_list::*;
pub use x509_parser::time::*;
pub use x509_parser::x509::*;
pub use x509_parser::der_parser::der;
pub use x509_parser::der_parser::ber;
pub use x509_parser::traits::*;
}
use std::fmt;
use std::ops::Deref;
use std::num::NonZeroUsize;
use ref_cast::RefCast; use ref_cast::RefCast;
use x509_parser::nom;
use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error, FromDer};
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;
use crate::listener::CertificateDer; use crate::mtls::{x509, oid, bigint, Name, Result, Error};
use crate::request::{Request, FromRequest, Outcome};
/// A type alias for [`Result`](std::result::Result) with the error type set to use crate::http::Status;
/// [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// A request guard for validated, verified client certificates. /// A request guard for validated, verified client certificates.
/// ///
@ -143,60 +100,42 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
/// ``` /// ```
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct Certificate<'a> { pub struct Certificate<'a> {
x509: X509Certificate<'a>, x509: x509::X509Certificate<'a>,
data: &'a CertificateDer, data: &'a CertificateDer<'a>,
} }
/// An X.509 Distinguished Name (DN) found in a [`Certificate`]. pub use rustls::pki_types::CertificateDer;
///
/// This type is a wrapper over [`x509::X509Name`] with convenient methods and
/// complete documentation. Should the data exposed by the inherent methods not
/// suffice, this type derefs to [`x509::X509Name`].
#[repr(transparent)]
#[derive(Debug, PartialEq, RefCast)]
pub struct Name<'a>(X509Name<'a>);
/// An error returned by the [`Certificate`] request guard. #[crate::async_trait]
/// impl<'r> FromRequest<'r> for Certificate<'r> {
/// To retrieve this error in a handler, use an `mtls::Result<Certificate>` type Error = Error;
/// guard type:
/// async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
/// ```rust use crate::outcome::{try_outcome, IntoOutcome};
/// # extern crate rocket;
/// # use rocket::get; let certs = req.connection
/// use rocket::mtls::{self, Certificate}; .peer_certs
/// .as_ref()
/// #[get("/auth")] .or_forward(Status::Unauthorized);
/// fn auth(cert: mtls::Result<Certificate<'_>>) {
/// match cert { let chain = try_outcome!(certs);
/// Ok(cert) => { /* do something with the client cert */ }, Certificate::parse(chain.inner()).or_error(Status::Unauthorized)
/// Err(e) => { /* do something with the error */ }, }
/// }
/// }
/// ```
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Error {
/// The certificate chain presented by the client had no certificates.
Empty,
/// The certificate contained neither a subject nor a subjectAlt extension.
NoSubject,
/// There is no subject and the subjectAlt is not marked as critical.
NonCriticalSubjectAlt,
/// An error occurred while parsing the certificate.
Parse(X509Error),
/// The certificate parsed partially but is incomplete.
///
/// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number
/// of expected bytes is unknown.
Incomplete(Option<NonZeroUsize>),
/// The certificate contained `.0` bytes of trailing data.
Trailing(usize),
} }
impl<'a> Certificate<'a> { impl<'a> Certificate<'a> {
fn parse_one(raw: &[u8]) -> Result<X509Certificate<'_>> { /// PRIVATE: For internal Rocket use only!
let (left, x509) = X509Certificate::from_der(raw)?; fn parse<'r>(chain: &'r [CertificateDer<'r>]) -> Result<Certificate<'r>> {
let data = chain.first().ok_or_else(|| Error::Empty)?;
let x509 = Certificate::parse_one(&*data)?;
Ok(Certificate { x509, data })
}
fn parse_one(raw: &[u8]) -> Result<x509::X509Certificate<'_>> {
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;
use x509_parser::traits::FromDer;
let (left, x509) = x509::X509Certificate::from_der(raw)?;
if !left.is_empty() { if !left.is_empty() {
return Err(Error::Trailing(left.len())); return Err(Error::Trailing(left.len()));
} }
@ -204,7 +143,7 @@ impl<'a> Certificate<'a> {
// Ensure we have a subject or a subjectAlt. // Ensure we have a subject or a subjectAlt.
if x509.subject().as_raw().is_empty() { if x509.subject().as_raw().is_empty() {
if let Some(ext) = x509.extensions().iter().find(|e| e.oid == SUBJECT_ALT_NAME) { if let Some(ext) = x509.extensions().iter().find(|e| e.oid == SUBJECT_ALT_NAME) {
if !matches!(ext.parsed_extension(), ParsedExtension::SubjectAlternativeName(..)) { if let x509::ParsedExtension::SubjectAlternativeName(..) = ext.parsed_extension() {
return Err(Error::NoSubject); return Err(Error::NoSubject);
} else if !ext.critical { } else if !ext.critical {
return Err(Error::NonCriticalSubjectAlt); return Err(Error::NonCriticalSubjectAlt);
@ -218,18 +157,10 @@ impl<'a> Certificate<'a> {
} }
#[inline(always)] #[inline(always)]
fn inner(&self) -> &TbsCertificate<'a> { fn inner(&self) -> &x509::TbsCertificate<'a> {
&self.x509.tbs_certificate &self.x509.tbs_certificate
} }
/// PRIVATE: For internal Rocket use only!
#[doc(hidden)]
pub fn parse(chain: &[CertificateDer]) -> Result<Certificate<'_>> {
let data = chain.first().ok_or_else(|| Error::Empty)?;
let x509 = Certificate::parse_one(&data.0)?;
Ok(Certificate { x509, data })
}
/// Returns the serial number of the X.509 certificate. /// Returns the serial number of the X.509 certificate.
/// ///
/// # Example /// # Example
@ -387,176 +318,14 @@ impl<'a> Certificate<'a> {
/// } /// }
/// ``` /// ```
pub fn as_bytes(&self) -> &'a [u8] { pub fn as_bytes(&self) -> &'a [u8] {
&self.data.0 &*self.data
} }
} }
impl<'a> Deref for Certificate<'a> { impl<'a> std::ops::Deref for Certificate<'a> {
type Target = TbsCertificate<'a>; type Target = x509::TbsCertificate<'a>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
self.inner() self.inner()
} }
} }
impl<'a> Name<'a> {
/// Returns the _first_ UTF-8 _string_ common name, if any.
///
/// Note that common names need not be UTF-8 strings, or strings at all.
/// This method returns the first common name attribute that is.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::Certificate;
///
/// #[get("/auth")]
/// fn auth(cert: Certificate<'_>) {
/// if let Some(name) = cert.subject().common_name() {
/// println!("Hello, {}!", name);
/// }
/// }
/// ```
pub fn common_name(&self) -> Option<&'a str> {
self.common_names().next()
}
/// Returns an iterator over all of the UTF-8 _string_ common names in
/// `self`.
///
/// Note that common names need not be UTF-8 strings, or strings at all.
/// This method filters the common names in `self` to those that are. Use
/// the raw [`iter_common_name()`](#method.iter_common_name) to iterate over
/// all value types.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::Certificate;
///
/// #[get("/auth")]
/// fn auth(cert: Certificate<'_>) {
/// for name in cert.issuer().common_names() {
/// println!("Issued by {}.", name);
/// }
/// }
/// ```
pub fn common_names(&self) -> impl Iterator<Item = &'a str> + '_ {
self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok())
}
/// Returns the _first_ UTF-8 _string_ email address, if any.
///
/// Note that email addresses need not be UTF-8 strings, or strings at all.
/// This method returns the first email address attribute that is.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::Certificate;
///
/// #[get("/auth")]
/// fn auth(cert: Certificate<'_>) {
/// if let Some(email) = cert.subject().email() {
/// println!("Hello, {}!", email);
/// }
/// }
/// ```
pub fn email(&self) -> Option<&'a str> {
self.emails().next()
}
/// Returns an iterator over all of the UTF-8 _string_ email addresses in
/// `self`.
///
/// Note that email addresses need not be UTF-8 strings, or strings at all.
/// This method filters the email address in `self` to those that are. Use
/// the raw [`iter_email()`](#method.iter_email) to iterate over all value
/// types.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::Certificate;
///
/// #[get("/auth")]
/// fn auth(cert: Certificate<'_>) {
/// for email in cert.subject().emails() {
/// println!("Reach me at: {}", email);
/// }
/// }
/// ```
pub fn emails(&self) -> impl Iterator<Item = &'a str> + '_ {
self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok())
}
/// Returns `true` if `self` has no data.
///
/// When this is the case for a `subject()`, the subject data can be found
/// in the `subjectAlt` [`extension()`](Certificate::extensions()).
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::Certificate;
///
/// #[get("/auth")]
/// fn auth(cert: Certificate<'_>) {
/// let no_data = cert.subject().is_empty();
/// }
/// ```
pub fn is_empty(&self) -> bool {
self.0.as_raw().is_empty()
}
}
impl<'a> Deref for Name<'a> {
type Target = X509Name<'a>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl fmt::Display for Name<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Parse(e) => write!(f, "parse error: {}", e),
Error::Incomplete(_) => write!(f, "incomplete certificate data"),
Error::Trailing(n) => write!(f, "found {} trailing bytes", n),
Error::Empty => write!(f, "empty certificate chain"),
Error::NoSubject => write!(f, "empty subject without subjectAlt"),
Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"),
}
}
}
impl From<nom::Err<X509Error>> for Error {
fn from(e: nom::Err<X509Error>) -> Self {
match e {
nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None),
nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)),
nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e),
}
}
}
impl std::error::Error for Error {
// fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
// match self {
// Error::Parse(e) => Some(e),
// _ => None
// }
// }
}

212
core/lib/src/mtls/config.rs Normal file
View File

@ -0,0 +1,212 @@
use std::io;
use figment::value::magic::{RelativePathBuf, Either};
use serde::{Serialize, Deserialize};
/// Mutual TLS configuration.
///
/// Configuration works in concert with the [`mtls`](crate::mtls) module, which
/// provides a request guard to validate, verify, and retrieve client
/// certificates in routes.
///
/// By default, mutual TLS is disabled and client certificates are not required,
/// validated or verified. To enable mutual TLS, the `mtls` feature must be
/// enabled and support configured via two `tls.mutual` parameters:
///
/// * `ca_certs`
///
/// A required path to a PEM file or raw bytes to a DER-encoded X.509 TLS
/// certificate chain for the certificate authority to verify client
/// certificates against. When a path is configured in a file, such as
/// `Rocket.toml`, relative paths are interpreted as relative to the source
/// file's directory.
///
/// * `mandatory`
///
/// An optional boolean that control whether client authentication is
/// required.
///
/// When `true`, client authentication is required. TLS connections where
/// the client does not present a certificate are immediately terminated.
/// When `false`, the client is not required to present a certificate. In
/// either case, if a certificate _is_ presented, it must be valid or the
/// connection is terminated.
///
/// In a `Rocket.toml`, configuration might look like:
///
/// ```toml
/// [default.tls.mutual]
/// ca_certs = "/ssl/ca_cert.pem"
/// mandatory = true # when absent, defaults to false
/// ```
///
/// Programmatically, configuration might look like:
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::MtlsConfig;
/// use rocket::figment::providers::Serialized;
///
/// #[launch]
/// fn rocket() -> _ {
/// let mtls = MtlsConfig::from_path("/ssl/ca_cert.pem");
/// rocket::custom(rocket::Config::figment().merge(("tls.mutual", mtls)))
/// }
/// ```
///
/// Once mTLS is configured, the [`mtls::Certificate`](crate::mtls::Certificate)
/// request guard can be used to retrieve client certificates in routes.
#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)]
pub struct MtlsConfig {
/// Path to a PEM file with, or raw bytes for, DER-encoded Certificate
/// Authority certificates which will be used to verify client-presented
/// certificates.
// TODO: Support more than one CA root.
pub(crate) ca_certs: Either<RelativePathBuf, Vec<u8>>,
/// Whether the client is required to present a certificate.
///
/// When `true`, the client is required to present a valid certificate to
/// proceed with TLS. When `false`, the client is not required to present a
/// certificate. In either case, if a certificate _is_ presented, it must be
/// valid or the connection is terminated.
#[serde(default)]
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
pub mandatory: bool,
}
impl MtlsConfig {
/// Constructs a `MtlsConfig` from a path to a PEM file with a certificate
/// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This
/// method does no validation; it simply creates a structure suitable for
/// passing into a [`TlsConfig`].
///
/// These certificates will be used to verify client-presented certificates
/// in TLS connections.
///
/// # Example
///
/// ```rust
/// use rocket::mtls::MtlsConfig;
///
/// let tls_config = MtlsConfig::from_path("/ssl/ca_certs.pem");
/// ```
pub fn from_path<C: AsRef<std::path::Path>>(ca_certs: C) -> Self {
MtlsConfig {
ca_certs: Either::Left(ca_certs.as_ref().to_path_buf().into()),
mandatory: Default::default()
}
}
/// Constructs a `MtlsConfig` from a byte buffer to a certificate authority
/// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no
/// validation; it simply creates a structure suitable for passing into a
/// [`TlsConfig`].
///
/// These certificates will be used to verify client-presented certificates
/// in TLS connections.
///
/// # Example
///
/// ```rust
/// use rocket::mtls::MtlsConfig;
///
/// # let ca_certs_buf = &[];
/// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf);
/// ```
pub fn from_bytes(ca_certs: &[u8]) -> Self {
MtlsConfig {
ca_certs: Either::Right(ca_certs.to_vec()),
mandatory: Default::default()
}
}
/// Sets whether client authentication is required. Disabled by default.
///
/// When `true`, client authentication will be required. TLS connections
/// where the client does not present a certificate will be immediately
/// terminated. When `false`, the client is not required to present a
/// certificate. In either case, if a certificate _is_ presented, it must be
/// valid or the connection is terminated.
///
/// # Example
///
/// ```rust
/// use rocket::mtls::MtlsConfig;
///
/// # let ca_certs_buf = &[];
/// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf).mandatory(true);
/// ```
pub fn mandatory(mut self, mandatory: bool) -> Self {
self.mandatory = mandatory;
self
}
/// Returns the value of the `ca_certs` parameter.
/// # Example
///
/// ```rust
/// use rocket::mtls::MtlsConfig;
///
/// # let ca_certs_buf = &[];
/// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf).mandatory(true);
/// assert_eq!(mtls_config.ca_certs().unwrap_right(), ca_certs_buf);
/// ```
pub fn ca_certs(&self) -> either::Either<std::path::PathBuf, &[u8]> {
match &self.ca_certs {
Either::Left(path) => either::Either::Left(path.relative()),
Either::Right(bytes) => either::Either::Right(&bytes),
}
}
#[inline(always)]
pub fn ca_certs_reader(&self) -> io::Result<Box<dyn io::BufRead + Sync + Send>> {
crate::tls::config::to_reader(&self.ca_certs)
}
}
#[cfg(test)]
mod tests {
use std::path::Path;
use figment::{Figment, providers::{Toml, Format}};
use crate::mtls::MtlsConfig;
#[test]
fn test_mtls_config() {
figment::Jail::expect_with(|jail| {
jail.create_file("MTLS.toml", r#"
certs = "/ssl/cert.pem"
key = "/ssl/key.pem"
"#)?;
let figment = || Figment::from(Toml::file("MTLS.toml"));
figment().extract::<MtlsConfig>().expect_err("no ca");
jail.create_file("MTLS.toml", r#"
ca_certs = "/ssl/ca.pem"
"#)?;
let mtls: MtlsConfig = figment().extract()?;
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
assert!(!mtls.mandatory);
jail.create_file("MTLS.toml", r#"
ca_certs = "/ssl/ca.pem"
mandatory = true
"#)?;
let mtls: MtlsConfig = figment().extract()?;
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
assert!(mtls.mandatory);
jail.create_file("MTLS.toml", r#"
ca_certs = "relative/ca.pem"
"#)?;
let mtls: MtlsConfig = figment().extract()?;
assert_eq!(mtls.ca_certs().unwrap_left(), jail.directory().join("relative/ca.pem"));
Ok(())
});
}
}

View File

@ -0,0 +1,74 @@
use std::fmt;
use std::num::NonZeroUsize;
use crate::mtls::x509::{self, nom};
/// An error returned by the [`Certificate`] request guard.
///
/// To retrieve this error in a handler, use an `mtls::Result<Certificate>`
/// guard type:
///
/// ```rust
/// # extern crate rocket;
/// # use rocket::get;
/// use rocket::mtls::{self, Certificate};
///
/// #[get("/auth")]
/// fn auth(cert: mtls::Result<Certificate<'_>>) {
/// match cert {
/// Ok(cert) => { /* do something with the client cert */ },
/// Err(e) => { /* do something with the error */ },
/// }
/// }
/// ```
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Error {
/// The certificate chain presented by the client had no certificates.
Empty,
/// The certificate contained neither a subject nor a subjectAlt extension.
NoSubject,
/// There is no subject and the subjectAlt is not marked as critical.
NonCriticalSubjectAlt,
/// An error occurred while parsing the certificate.
Parse(x509::X509Error),
/// The certificate parsed partially but is incomplete.
///
/// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number
/// of expected bytes is unknown.
Incomplete(Option<NonZeroUsize>),
/// The certificate contained `.0` bytes of trailing data.
Trailing(usize),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Parse(e) => write!(f, "parse error: {}", e),
Error::Incomplete(_) => write!(f, "incomplete certificate data"),
Error::Trailing(n) => write!(f, "found {} trailing bytes", n),
Error::Empty => write!(f, "empty certificate chain"),
Error::NoSubject => write!(f, "empty subject without subjectAlt"),
Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"),
}
}
}
impl From<nom::Err<x509::X509Error>> for Error {
fn from(e: nom::Err<x509::X509Error>) -> Self {
match e {
nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None),
nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)),
nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e),
}
}
}
impl std::error::Error for Error {
// fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
// match self {
// Error::Parse(e) => Some(e),
// _ => None
// }
// }
}

56
core/lib/src/mtls/mod.rs Normal file
View File

@ -0,0 +1,56 @@
//! Support for mutual TLS client certificates.
//!
//! For details on how to configure mutual TLS, see
//! [`MutualTls`](crate::config::MutualTls) and the [TLS
//! guide](https://rocket.rs/master/guide/configuration/#tls). See
//! [`Certificate`] for a request guard that validated, verifies, and retrieves
//! client certificates.
pub mod oid {
//! Lower-level OID types re-exported from
//! [`oid_registry`](https://docs.rs/oid-registry/0.4) and
//! [`der-parser`](https://docs.rs/der-parser/7).
pub use x509_parser::oid_registry::*;
pub use x509_parser::objects::*;
}
pub mod bigint {
//! Signed and unsigned big integer types re-exported from
//! [`num_bigint`](https://docs.rs/num-bigint/0.4).
pub use x509_parser::der_parser::num_bigint::*;
}
pub mod x509 {
//! Lower-level X.509 types re-exported from
//! [`x509_parser`](https://docs.rs/x509-parser/0.13).
//!
//! Lack of documentation is directly inherited from the source crate.
//! Prefer to use Rocket's wrappers when possible.
pub(crate) use x509_parser::nom;
pub use x509_parser::certificate::*;
pub use x509_parser::cri_attributes::*;
pub use x509_parser::error::*;
pub use x509_parser::extensions::*;
pub use x509_parser::revocation_list::*;
pub use x509_parser::time::*;
pub use x509_parser::x509::*;
pub use x509_parser::der_parser::der;
pub use x509_parser::der_parser::ber;
pub use x509_parser::traits::*;
}
mod certificate;
mod error;
mod name;
mod config;
pub use error::Error;
pub use name::Name;
pub use config::MtlsConfig;
pub use certificate::{Certificate, CertificateDer};
/// A type alias for [`Result`](std::result::Result) with the error type set to
/// [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;

146
core/lib/src/mtls/name.rs Normal file
View File

@ -0,0 +1,146 @@
use std::fmt;
use std::ops::Deref;
use ref_cast::RefCast;
use crate::mtls::x509::X509Name;
use crate::mtls::oid;
/// An X.509 Distinguished Name (DN) found in a [`Certificate`].
///
/// This type is a wrapper over [`x509::X509Name`] with convenient methods and
/// complete documentation. Should the data exposed by the inherent methods not
/// suffice, this type derefs to [`x509::X509Name`].
#[repr(transparent)]
#[derive(Debug, PartialEq, RefCast)]
pub struct Name<'a>(X509Name<'a>);
impl<'a> Name<'a> {
/// Returns the _first_ UTF-8 _string_ common name, if any.
///
/// Note that common names need not be UTF-8 strings, or strings at all.
/// This method returns the first common name attribute that is.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::Certificate;
///
/// #[get("/auth")]
/// fn auth(cert: Certificate<'_>) {
/// if let Some(name) = cert.subject().common_name() {
/// println!("Hello, {}!", name);
/// }
/// }
/// ```
pub fn common_name(&self) -> Option<&'a str> {
self.common_names().next()
}
/// Returns an iterator over all of the UTF-8 _string_ common names in
/// `self`.
///
/// Note that common names need not be UTF-8 strings, or strings at all.
/// This method filters the common names in `self` to those that are. Use
/// the raw [`iter_common_name()`](#method.iter_common_name) to iterate over
/// all value types.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::Certificate;
///
/// #[get("/auth")]
/// fn auth(cert: Certificate<'_>) {
/// for name in cert.issuer().common_names() {
/// println!("Issued by {}.", name);
/// }
/// }
/// ```
pub fn common_names(&self) -> impl Iterator<Item = &'a str> + '_ {
self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok())
}
/// Returns the _first_ UTF-8 _string_ email address, if any.
///
/// Note that email addresses need not be UTF-8 strings, or strings at all.
/// This method returns the first email address attribute that is.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::Certificate;
///
/// #[get("/auth")]
/// fn auth(cert: Certificate<'_>) {
/// if let Some(email) = cert.subject().email() {
/// println!("Hello, {}!", email);
/// }
/// }
/// ```
pub fn email(&self) -> Option<&'a str> {
self.emails().next()
}
/// Returns an iterator over all of the UTF-8 _string_ email addresses in
/// `self`.
///
/// Note that email addresses need not be UTF-8 strings, or strings at all.
/// This method filters the email address in `self` to those that are. Use
/// the raw [`iter_email()`](#method.iter_email) to iterate over all value
/// types.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::Certificate;
///
/// #[get("/auth")]
/// fn auth(cert: Certificate<'_>) {
/// for email in cert.subject().emails() {
/// println!("Reach me at: {}", email);
/// }
/// }
/// ```
pub fn emails(&self) -> impl Iterator<Item = &'a str> + '_ {
self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok())
}
/// Returns `true` if `self` has no data.
///
/// When this is the case for a `subject()`, the subject data can be found
/// in the `subjectAlt` [`extension()`](Certificate::extensions()).
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use rocket::mtls::Certificate;
///
/// #[get("/auth")]
/// fn auth(cert: Certificate<'_>) {
/// let no_data = cert.subject().is_empty();
/// }
/// ```
pub fn is_empty(&self) -> bool {
self.0.as_raw().is_empty()
}
}
impl<'a> Deref for Name<'a> {
type Target = X509Name<'a>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl fmt::Display for Name<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

View File

@ -1,6 +1,7 @@
use state::TypeMap; use state::TypeMap;
use figment::Figment; use figment::Figment;
use crate::listener::Endpoint;
use crate::{Catcher, Config, Rocket, Route, Shutdown}; use crate::{Catcher, Config, Rocket, Route, Shutdown};
use crate::router::Router; use crate::router::Router;
use crate::fairing::Fairings; use crate::fairing::Fairings;
@ -113,5 +114,6 @@ phases! {
pub(crate) config: Config, pub(crate) config: Config,
pub(crate) state: TypeMap![Send + Sync], pub(crate) state: TypeMap![Send + Sync],
pub(crate) shutdown: Shutdown, pub(crate) shutdown: Shutdown,
pub(crate) endpoint: Endpoint,
} }
} }

View File

@ -0,0 +1,43 @@
use crate::http::Method;
pub struct AtomicMethod(ref_swap::RefSwap<'static, Method>);
#[inline(always)]
const fn makeref(method: Method) -> &'static Method {
match method {
Method::Get => &Method::Get,
Method::Put => &Method::Put,
Method::Post => &Method::Post,
Method::Delete => &Method::Delete,
Method::Options => &Method::Options,
Method::Head => &Method::Head,
Method::Trace => &Method::Trace,
Method::Connect => &Method::Connect,
Method::Patch => &Method::Patch,
}
}
impl AtomicMethod {
pub fn new(value: Method) -> Self {
Self(ref_swap::RefSwap::new(makeref(value)))
}
pub fn load(&self) -> Method {
*self.0.load(std::sync::atomic::Ordering::Acquire)
}
pub fn set(&mut self, new: Method) {
*self = Self::new(new);
}
pub fn store(&self, new: Method) {
self.0.store(makeref(new), std::sync::atomic::Ordering::Release)
}
}
impl Clone for AtomicMethod {
fn clone(&self) -> Self {
let inner = self.0.load(std::sync::atomic::Ordering::Acquire);
Self(ref_swap::RefSwap::new(inner))
}
}

View File

@ -1,12 +1,13 @@
use std::convert::Infallible; use std::convert::Infallible;
use std::fmt::Debug; use std::fmt::Debug;
use std::net::{IpAddr, SocketAddr}; use std::net::IpAddr;
use crate::{Request, Route}; use crate::{Request, Route};
use crate::outcome::{self, IntoOutcome, Outcome::*}; use crate::outcome::{self, IntoOutcome, Outcome::*};
use crate::http::uri::{Host, Origin}; use crate::http::uri::{Host, Origin};
use crate::http::{Status, ContentType, Accept, Method, ProxyProto, CookieJar}; use crate::http::{Status, ContentType, Accept, Method, ProxyProto, CookieJar};
use crate::listener::Endpoint;
/// Type alias for the `Outcome` of a `FromRequest` conversion. /// Type alias for the `Outcome` of a `FromRequest` conversion.
pub type Outcome<S, E> = outcome::Outcome<S, (Status, E), Status>; pub type Outcome<S, E> = outcome::Outcome<S, (Status, E), Status>;
@ -486,14 +487,22 @@ impl<'r> FromRequest<'r> for ProxyProto<'r> {
} }
#[crate::async_trait] #[crate::async_trait]
impl<'r> FromRequest<'r> for SocketAddr { impl<'r> FromRequest<'r> for &'r Endpoint {
type Error = Infallible; type Error = Infallible;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
match request.remote() { request.remote().or_forward(Status::InternalServerError)
Some(addr) => Success(addr),
None => Forward(Status::InternalServerError)
} }
}
#[crate::async_trait]
impl<'r> FromRequest<'r> for std::net::SocketAddr {
type Error = Infallible;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
request.remote()
.and_then(|r| r.tcp())
.or_forward(Status::InternalServerError)
} }
} }

View File

@ -3,6 +3,7 @@
mod request; mod request;
mod from_param; mod from_param;
mod from_request; mod from_request;
mod atomic_method;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
@ -15,6 +16,7 @@ pub use self::from_param::{FromParam, FromSegments};
pub use crate::response::flash::FlashMessage; pub use crate::response::flash::FlashMessage;
pub(crate) use self::request::ConnectionMeta; pub(crate) use self::request::ConnectionMeta;
pub(crate) use self::atomic_method::AtomicMethod;
crate::export! { crate::export! {
/// Store and immediately retrieve a vector-like value `$v` (`String` or /// Store and immediately retrieve a vector-like value `$v` (`String` or

View File

@ -1,22 +1,24 @@
use std::fmt; use std::fmt;
use std::ops::RangeFrom; use std::ops::RangeFrom;
use std::{future::Future, borrow::Cow, sync::Arc}; use std::sync::{Arc, atomic::Ordering};
use std::net::{IpAddr, SocketAddr}; use std::borrow::Cow;
use std::future::Future;
use std::net::IpAddr;
use yansi::Paint; use yansi::Paint;
use state::{TypeMap, InitCell}; use state::{TypeMap, InitCell};
use futures::future::BoxFuture; use futures::future::BoxFuture;
use atomic::{Atomic, Ordering}; use ref_swap::OptionRefSwap;
use crate::{Rocket, Route, Orbit}; use crate::{Rocket, Route, Orbit};
use crate::request::{FromParam, FromSegments, FromRequest, Outcome}; use crate::request::{FromParam, FromSegments, FromRequest, Outcome, AtomicMethod};
use crate::form::{self, ValueField, FromForm}; use crate::form::{self, ValueField, FromForm};
use crate::data::Limits; use crate::data::Limits;
use crate::http::{hyper, Method, Header, HeaderMap, ProxyProto}; use crate::http::ProxyProto;
use crate::http::{ContentType, Accept, MediaType, CookieJar, Cookie}; use crate::http::{Method, Header, HeaderMap, ContentType, Accept, MediaType, CookieJar, Cookie};
use crate::http::private::Certificates;
use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority}; use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority};
use crate::listener::{Certificates, Endpoint, Connection};
/// The type of an incoming web request. /// The type of an incoming web request.
/// ///
@ -24,26 +26,37 @@ use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority};
/// should likely only be used when writing [`FromRequest`] implementations. It /// should likely only be used when writing [`FromRequest`] implementations. It
/// contains all of the information for a given web request except for the body /// contains all of the information for a given web request except for the body
/// data. This includes the HTTP method, URI, cookies, headers, and more. /// data. This includes the HTTP method, URI, cookies, headers, and more.
#[derive(Clone)]
pub struct Request<'r> { pub struct Request<'r> {
method: Atomic<Method>, method: AtomicMethod,
uri: Origin<'r>, uri: Origin<'r>,
headers: HeaderMap<'r>, headers: HeaderMap<'r>,
pub(crate) errors: Vec<RequestError>,
pub(crate) connection: ConnectionMeta, pub(crate) connection: ConnectionMeta,
pub(crate) state: RequestState<'r>, pub(crate) state: RequestState<'r>,
} }
/// Information derived from an incoming connection, if any. /// Information derived from an incoming connection, if any.
#[derive(Clone)] #[derive(Clone, Default)]
pub(crate) struct ConnectionMeta { pub(crate) struct ConnectionMeta {
pub remote: Option<SocketAddr>, pub peer_address: Option<Arc<Endpoint>>,
#[cfg_attr(not(feature = "mtls"), allow(dead_code))] #[cfg_attr(not(feature = "mtls"), allow(dead_code))]
pub client_certificates: Option<Certificates>, pub peer_certs: Option<Arc<Certificates<'static>>>,
}
impl<C: Connection> From<&C> for ConnectionMeta {
fn from(conn: &C) -> Self {
ConnectionMeta {
peer_address: conn.peer_address().ok().map(Arc::new),
peer_certs: conn.peer_certificates().map(|c| c.into_owned()).map(Arc::new),
}
}
} }
/// Information derived from the request. /// Information derived from the request.
pub(crate) struct RequestState<'r> { pub(crate) struct RequestState<'r> {
pub rocket: &'r Rocket<Orbit>, pub rocket: &'r Rocket<Orbit>,
pub route: Atomic<Option<&'r Route>>, pub route: OptionRefSwap<'r, Route>,
pub cookies: CookieJar<'r>, pub cookies: CookieJar<'r>,
pub accept: InitCell<Option<Accept>>, pub accept: InitCell<Option<Accept>>,
pub content_type: InitCell<Option<ContentType>>, pub content_type: InitCell<Option<ContentType>>,
@ -51,23 +64,11 @@ pub(crate) struct RequestState<'r> {
pub host: Option<Host<'r>>, pub host: Option<Host<'r>>,
} }
impl Request<'_> { impl Clone for RequestState<'_> {
pub(crate) fn clone(&self) -> Self {
Request {
method: Atomic::new(self.method()),
uri: self.uri.clone(),
headers: self.headers.clone(),
connection: self.connection.clone(),
state: self.state.clone(),
}
}
}
impl RequestState<'_> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
RequestState { RequestState {
rocket: self.rocket, rocket: self.rocket,
route: Atomic::new(self.route.load(Ordering::Acquire)), route: OptionRefSwap::new(self.route.load(Ordering::Acquire)),
cookies: self.cookies.clone(), cookies: self.cookies.clone(),
accept: self.accept.clone(), accept: self.accept.clone(),
content_type: self.content_type.clone(), content_type: self.content_type.clone(),
@ -87,15 +88,13 @@ impl<'r> Request<'r> {
) -> Request<'r> { ) -> Request<'r> {
Request { Request {
uri, uri,
method: Atomic::new(method), method: AtomicMethod::new(method),
headers: HeaderMap::new(), headers: HeaderMap::new(),
connection: ConnectionMeta { errors: Vec::new(),
remote: None, connection: ConnectionMeta::default(),
client_certificates: None,
},
state: RequestState { state: RequestState {
rocket, rocket,
route: Atomic::new(None), route: OptionRefSwap::new(None),
cookies: CookieJar::new(None, rocket), cookies: CookieJar::new(None, rocket),
accept: InitCell::new(), accept: InitCell::new(),
content_type: InitCell::new(), content_type: InitCell::new(),
@ -120,7 +119,7 @@ impl<'r> Request<'r> {
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn method(&self) -> Method { pub fn method(&self) -> Method {
self.method.load(Ordering::Acquire) self.method.load()
} }
/// Set the method of `self` to `method`. /// Set the method of `self` to `method`.
@ -140,7 +139,7 @@ impl<'r> Request<'r> {
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn set_method(&mut self, method: Method) { pub fn set_method(&mut self, method: Method) {
self._set_method(method); self.method.set(method);
} }
/// Borrow the [`Origin`] URI from `self`. /// Borrow the [`Origin`] URI from `self`.
@ -324,20 +323,20 @@ impl<'r> Request<'r> {
/// ///
/// assert_eq!(request.remote(), None); /// assert_eq!(request.remote(), None);
/// ///
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000).into(); /// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000);
/// request.set_remote(localhost); /// request.set_remote(localhost);
/// assert_eq!(request.remote(), Some(localhost)); /// assert_eq!(request.remote().unwrap(), &localhost);
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn remote(&self) -> Option<SocketAddr> { pub fn remote(&self) -> Option<&Endpoint> {
self.connection.remote self.connection.peer_address.as_deref()
} }
/// Sets the remote address of `self` to `address`. /// Sets the remote address of `self` to `address`.
/// ///
/// # Example /// # Example
/// ///
/// Set the remote address to be 127.0.0.1:8000: /// Set the remote address to be 127.0.0.1:8111:
/// ///
/// ```rust /// ```rust
/// use std::net::{SocketAddrV4, Ipv4Addr}; /// use std::net::{SocketAddrV4, Ipv4Addr};
@ -347,13 +346,13 @@ impl<'r> Request<'r> {
/// ///
/// assert_eq!(request.remote(), None); /// assert_eq!(request.remote(), None);
/// ///
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000).into(); /// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8111);
/// request.set_remote(localhost); /// request.set_remote(localhost);
/// assert_eq!(request.remote(), Some(localhost)); /// assert_eq!(request.remote().unwrap(), &localhost);
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn set_remote(&mut self, address: SocketAddr) { pub fn set_remote<A: Into<Endpoint>>(&mut self, address: A) {
self.connection.remote = Some(address); self.connection.peer_address = Some(Arc::new(address.into()));
} }
/// Returns the IP address of the configured /// Returns the IP address of the configured
@ -489,25 +488,26 @@ impl<'r> Request<'r> {
/// ///
/// ```rust /// ```rust
/// # use rocket::http::Header; /// # use rocket::http::Header;
/// # use std::net::{SocketAddr, IpAddr, Ipv4Addr};
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap(); /// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let mut req = c.get("/"); /// # let mut req = c.get("/");
/// # let request = req.inner_mut(); /// # let request = req.inner_mut();
/// # use std::net::{SocketAddrV4, Ipv4Addr};
/// ///
/// // starting without an "X-Real-IP" header or remote address /// // starting without an "X-Real-IP" header or remote address
/// assert!(request.client_ip().is_none()); /// assert!(request.client_ip().is_none());
/// ///
/// // add a remote address; this is done by Rocket automatically /// // add a remote address; this is done by Rocket automatically
/// request.set_remote("127.0.0.1:8000".parse().unwrap()); /// let localhost_9190 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9190);
/// assert_eq!(request.client_ip(), Some("127.0.0.1".parse().unwrap())); /// request.set_remote(localhost_9190);
/// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::LOCALHOST);
/// ///
/// // now with an X-Real-IP header, the default value for `ip_header`. /// // now with an X-Real-IP header, the default value for `ip_header`.
/// request.add_header(Header::new("X-Real-IP", "8.8.8.8")); /// request.add_header(Header::new("X-Real-IP", "8.8.8.8"));
/// assert_eq!(request.client_ip(), Some("8.8.8.8".parse().unwrap())); /// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::new(8, 8, 8, 8));
/// ``` /// ```
#[inline] #[inline]
pub fn client_ip(&self) -> Option<IpAddr> { pub fn client_ip(&self) -> Option<IpAddr> {
self.real_ip().or_else(|| self.remote().map(|r| r.ip())) self.real_ip().or_else(|| Some(self.remote()?.tcp()?.ip()))
} }
/// Returns a wrapped borrow to the cookies in `self`. /// Returns a wrapped borrow to the cookies in `self`.
@ -691,7 +691,7 @@ impl<'r> Request<'r> {
if self.method().supports_payload() { if self.method().supports_payload() {
self.content_type().map(|ct| ct.media_type()) self.content_type().map(|ct| ct.media_type())
} else { } else {
// FIXME: Should we be using `accept_first` or `preferred`? Or // TODO: Should we be using `accept_first` or `preferred`? Or
// should we be checking neither and instead pass things through // should we be checking neither and instead pass things through
// where the client accepts the thing at all? // where the client accepts the thing at all?
self.accept() self.accept()
@ -1056,11 +1056,9 @@ impl<'r> Request<'r> {
self.state.route.store(Some(route), Ordering::Release) self.state.route.store(Some(route), Ordering::Release)
} }
/// Set the method of `self`, even when `self` is a shared reference. Used
/// during routing to override methods for re-routing.
#[inline(always)] #[inline(always)]
pub(crate) fn _set_method(&self, method: Method) { pub(crate) fn _set_method(&self, method: Method) {
self.method.store(method, Ordering::Release) self.method.store(method)
} }
pub(crate) fn cookies_mut(&mut self) -> &mut CookieJar<'r> { pub(crate) fn cookies_mut(&mut self) -> &mut CookieJar<'r> {
@ -1070,18 +1068,28 @@ impl<'r> Request<'r> {
/// Convert from Hyper types into a Rocket Request. /// Convert from Hyper types into a Rocket Request.
pub(crate) fn from_hyp( pub(crate) fn from_hyp(
rocket: &'r Rocket<Orbit>, rocket: &'r Rocket<Orbit>,
hyper: &'r hyper::request::Parts, hyper: &'r hyper::http::request::Parts,
connection: Option<ConnectionMeta>, connection: ConnectionMeta,
) -> Result<Request<'r>, BadRequest<'r>> { ) -> Result<Request<'r>, Request<'r>> {
// Keep track of parsing errors; emit a `BadRequest` if any exist. // Keep track of parsing errors; emit a `BadRequest` if any exist.
let mut errors = vec![]; let mut errors = vec![];
// Ensure that the method is known. TODO: Allow made-up methods? // Ensure that the method is known. TODO: Allow made-up methods?
let method = Method::from_hyp(&hyper.method) let method = match hyper.method {
.unwrap_or_else(|| { hyper::Method::GET => Method::Get,
errors.push(Kind::BadMethod(&hyper.method)); hyper::Method::PUT => Method::Put,
hyper::Method::POST => Method::Post,
hyper::Method::DELETE => Method::Delete,
hyper::Method::OPTIONS => Method::Options,
hyper::Method::HEAD => Method::Head,
hyper::Method::TRACE => Method::Trace,
hyper::Method::CONNECT => Method::Connect,
hyper::Method::PATCH => Method::Patch,
_ => {
errors.push(RequestError::BadMethod(hyper.method.clone()));
Method::Get Method::Get
}); }
};
// TODO: Keep around not just the path/query, but the rest, if there? // TODO: Keep around not just the path/query, but the rest, if there?
let uri = hyper.uri.path_and_query() let uri = hyper.uri.path_and_query()
@ -1100,20 +1108,20 @@ impl<'r> Request<'r> {
Origin::new(uri.path(), uri.query().map(Cow::Borrowed)) Origin::new(uri.path(), uri.query().map(Cow::Borrowed))
}) })
.unwrap_or_else(|| { .unwrap_or_else(|| {
errors.push(Kind::InvalidUri(&hyper.uri)); errors.push(RequestError::InvalidUri(hyper.uri.clone()));
Origin::ROOT Origin::ROOT
}); });
// Construct the request object; fill in metadata and headers next. // Construct the request object; fill in metadata and headers next.
let mut request = Request::new(rocket, method, uri); let mut request = Request::new(rocket, method, uri);
request.errors = errors;
// Set the passed in connection metadata. // Set the passed in connection metadata.
if let Some(connection) = connection {
request.connection = connection; request.connection = connection;
}
// Determine + set host. On HTTP < 2, use the `HOST` header. Otherwise, // Determine + set host. On HTTP < 2, use the `HOST` header. Otherwise,
// use the `:authority` pseudo-header which hyper makes part of the URI. // use the `:authority` pseudo-header which hyper makes part of the URI.
// TODO: Use an `InitCell` to compute this later.
request.state.host = if hyper.version < hyper::Version::HTTP_2 { request.state.host = if hyper.version < hyper::Version::HTTP_2 {
hyper.headers.get("host").and_then(|h| Host::parse_bytes(h.as_bytes()).ok()) hyper.headers.get("host").and_then(|h| Host::parse_bytes(h.as_bytes()).ok())
} else { } else {
@ -1122,9 +1130,8 @@ impl<'r> Request<'r> {
// Set the request cookies, if they exist. // Set the request cookies, if they exist.
for header in hyper.headers.get_all("Cookie") { for header in hyper.headers.get_all("Cookie") {
let raw_str = match std::str::from_utf8(header.as_bytes()) { let Ok(raw_str) = std::str::from_utf8(header.as_bytes()) else {
Ok(string) => string, continue
Err(_) => continue
}; };
for cookie_str in raw_str.split(';').map(|s| s.trim()) { for cookie_str in raw_str.split(';').map(|s| s.trim()) {
@ -1137,43 +1144,33 @@ impl<'r> Request<'r> {
// Set the rest of the headers. This is rather unfortunate and slow. // Set the rest of the headers. This is rather unfortunate and slow.
for (name, value) in hyper.headers.iter() { for (name, value) in hyper.headers.iter() {
// FIXME: This is rather unfortunate. Header values needn't be UTF8. // FIXME: This is rather unfortunate. Header values needn't be UTF8.
let value = match std::str::from_utf8(value.as_bytes()) { let Ok(value) = std::str::from_utf8(value.as_bytes()) else {
Ok(value) => value,
Err(_) => {
warn!("Header '{}' contains invalid UTF-8", name); warn!("Header '{}' contains invalid UTF-8", name);
warn_!("Rocket only supports UTF-8 header values. Dropping header."); warn_!("Rocket only supports UTF-8 header values. Dropping header.");
continue; continue;
}
}; };
request.add_header(Header::new(name.as_str(), value)); request.add_header(Header::new(name.as_str(), value));
} }
if errors.is_empty() { match request.errors.is_empty() {
Ok(request) true => Ok(request),
} else { false => Err(request),
Err(BadRequest { request, errors })
} }
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub(crate) struct BadRequest<'r> { pub(crate) enum RequestError {
pub request: Request<'r>, InvalidUri(hyper::Uri),
pub errors: Vec<Kind<'r>>, BadMethod(hyper::Method),
} }
#[derive(Debug)] impl fmt::Display for RequestError {
pub(crate) enum Kind<'r> {
InvalidUri(&'r hyper::Uri),
BadMethod(&'r hyper::Method),
}
impl fmt::Display for Kind<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
Kind::InvalidUri(u) => write!(f, "invalid origin URI: {}", u), RequestError::InvalidUri(u) => write!(f, "invalid origin URI: {}", u),
Kind::BadMethod(m) => write!(f, "invalid or unrecognized method: {}", m), RequestError::BadMethod(m) => write!(f, "invalid or unrecognized method: {}", m),
} }
} }
} }
@ -1181,8 +1178,8 @@ impl fmt::Display for Kind<'_> {
impl fmt::Debug for Request<'_> { impl fmt::Debug for Request<'_> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Request") fmt.debug_struct("Request")
.field("method", &self.method) .field("method", &self.method())
.field("uri", &self.uri) .field("uri", &self.uri())
.field("headers", &self.headers()) .field("headers", &self.headers())
.field("remote", &self.remote()) .field("remote", &self.remote())
.field("cookies", &self.cookies()) .field("cookies", &self.cookies())

View File

@ -1,14 +1,16 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::Request; use crate::request::{Request, ConnectionMeta};
use crate::local::blocking::Client; use crate::local::blocking::Client;
use crate::http::hyper;
macro_rules! assert_headers { macro_rules! assert_headers {
($($key:expr => [$($value:expr),+]),+) => ({ ($($key:expr => [$($value:expr),+]),+) => ({
// Create a new Hyper request. Add all of the passed in headers. // Create a new Hyper request. Add all of the passed in headers.
let mut req = hyper::Request::get("/test").body(()).unwrap(); let mut req = hyper::Request::get("/test").body(()).unwrap();
$($(req.headers_mut().append($key, hyper::HeaderValue::from_str($value).unwrap());)+)+ $($(
req.headers_mut()
.append($key, hyper::header::HeaderValue::from_str($value).unwrap());
)+)+
// Build up what we expect the headers to actually be. // Build up what we expect the headers to actually be.
let mut expected = HashMap::new(); let mut expected = HashMap::new();
@ -17,7 +19,8 @@ macro_rules! assert_headers {
// Create a valid `Rocket` and convert the hyper req to a Rocket one. // Create a valid `Rocket` and convert the hyper req to a Rocket one.
let client = Client::debug_with(vec![]).unwrap(); let client = Client::debug_with(vec![]).unwrap();
let hyper = req.into_parts().0; let hyper = req.into_parts().0;
let req = Request::from_hyp(client.rocket(), &hyper, None).unwrap(); let meta = ConnectionMeta::default();
let req = Request::from_hyp(client.rocket(), &hyper, meta).unwrap();
// Dispatch the request and check that the headers match. // Dispatch the request and check that the headers match.
let actual_headers = req.headers(); let actual_headers = req.headers();

View File

@ -1,7 +1,6 @@
use std::{fmt, str}; use std::{fmt, str};
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::HashMap; use std::collections::HashMap;
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncSeek}; use tokio::io::{AsyncRead, AsyncSeek};
@ -146,19 +145,18 @@ impl<'r> Builder<'r> {
/// potentially different values to be present in the `Response`. /// potentially different values to be present in the `Response`.
/// ///
/// The type of `header` can be any type that implements `Into<Header>`. /// The type of `header` can be any type that implements `Into<Header>`.
/// This includes `Header` itself, [`ContentType`](crate::http::ContentType) and /// This includes `Header` itself, [`ContentType`](crate::http::ContentType)
/// [hyper::header types](crate::http::hyper::header). /// and [`Accept`](crate::http::Accept).
/// ///
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use rocket::Response; /// use rocket::Response;
/// use rocket::http::Header; /// use rocket::http::{Header, Accept};
/// use rocket::http::hyper::header::ACCEPT;
/// ///
/// let response = Response::build() /// let response = Response::build()
/// .header_adjoin(Header::new(ACCEPT.as_str(), "application/json")) /// .header_adjoin(Header::new("Accept", "application/json"))
/// .header_adjoin(Header::new(ACCEPT.as_str(), "text/plain")) /// .header_adjoin(Accept::XML)
/// .finalize(); /// .finalize();
/// ///
/// assert_eq!(response.headers().get("Accept").count(), 2); /// assert_eq!(response.headers().get("Accept").count(), 2);
@ -287,7 +285,7 @@ impl<'r> Builder<'r> {
/// ///
/// #[rocket::async_trait] /// #[rocket::async_trait]
/// impl IoHandler for EchoHandler { /// impl IoHandler for EchoHandler {
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> { /// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io); /// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?; /// io::copy(&mut reader, &mut writer).await?;
/// Ok(()) /// Ok(())
@ -488,7 +486,7 @@ pub struct Response<'r> {
status: Option<Status>, status: Option<Status>,
headers: HeaderMap<'r>, headers: HeaderMap<'r>,
body: Body<'r>, body: Body<'r>,
upgrade: HashMap<Uncased<'r>, Pin<Box<dyn IoHandler + 'r>>>, upgrade: HashMap<Uncased<'r>, Box<dyn IoHandler + 'r>>,
} }
impl<'r> Response<'r> { impl<'r> Response<'r> {
@ -700,23 +698,22 @@ impl<'r> Response<'r> {
/// name `header.name`, another header with the same name and value /// name `header.name`, another header with the same name and value
/// `header.value` is added. The type of `header` can be any type that /// `header.value` is added. The type of `header` can be any type that
/// implements `Into<Header>`. This includes `Header` itself, /// implements `Into<Header>`. This includes `Header` itself,
/// [`ContentType`](crate::http::ContentType) and [`hyper::header` /// [`ContentType`](crate::http::ContentType),
/// types](crate::http::hyper::header). /// [`Accept`](crate::http::Accept).
/// ///
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use rocket::Response; /// use rocket::Response;
/// use rocket::http::Header; /// use rocket::http::{Header, Accept};
/// use rocket::http::hyper::header::ACCEPT;
/// ///
/// let mut response = Response::new(); /// let mut response = Response::new();
/// response.adjoin_header(Header::new(ACCEPT.as_str(), "application/json")); /// response.adjoin_header(Accept::JSON);
/// response.adjoin_header(Header::new(ACCEPT.as_str(), "text/plain")); /// response.adjoin_header(Header::new("Accept", "text/plain"));
/// ///
/// let mut accept_headers = response.headers().iter(); /// let mut accept_headers = response.headers().iter();
/// assert_eq!(accept_headers.next(), Some(Header::new(ACCEPT.as_str(), "application/json"))); /// assert_eq!(accept_headers.next(), Some(Header::new("Accept", "application/json")));
/// assert_eq!(accept_headers.next(), Some(Header::new(ACCEPT.as_str(), "text/plain"))); /// assert_eq!(accept_headers.next(), Some(Header::new("Accept", "text/plain")));
/// assert_eq!(accept_headers.next(), None); /// assert_eq!(accept_headers.next(), None);
/// ``` /// ```
#[inline(always)] #[inline(always)]
@ -801,10 +798,10 @@ impl<'r> Response<'r> {
/// the comma-separated protocols any of the strings in `I`. Returns /// the comma-separated protocols any of the strings in `I`. Returns
/// `Ok(None)` if `self` doesn't support any kind of upgrade. Returns /// `Ok(None)` if `self` doesn't support any kind of upgrade. Returns
/// `Err(_)` if `protocols` is non-empty but no match was found in `self`. /// `Err(_)` if `protocols` is non-empty but no match was found in `self`.
pub(crate) fn take_upgrade<I: Iterator<Item = &'r str>>( pub(crate) fn search_upgrades<'a, I: Iterator<Item = &'a str>>(
&mut self, &mut self,
protocols: I protocols: I
) -> Result<Option<(Uncased<'r>, Pin<Box<dyn IoHandler + 'r>>)>, ()> { ) -> Result<Option<(Uncased<'r>, Box<dyn IoHandler + 'r>)>, ()> {
if self.upgrade.is_empty() { if self.upgrade.is_empty() {
return Ok(None); return Ok(None);
} }
@ -839,7 +836,7 @@ impl<'r> Response<'r> {
/// ///
/// #[rocket::async_trait] /// #[rocket::async_trait]
/// impl IoHandler for EchoHandler { /// impl IoHandler for EchoHandler {
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> { /// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io); /// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?; /// io::copy(&mut reader, &mut writer).await?;
/// Ok(()) /// Ok(())
@ -854,7 +851,7 @@ impl<'r> Response<'r> {
/// assert!(response.upgrade("raw-echo").is_some()); /// assert!(response.upgrade("raw-echo").is_some());
/// # }) /// # })
/// ``` /// ```
pub fn upgrade(&mut self, proto: &str) -> Option<Pin<&mut (dyn IoHandler + 'r)>> { pub fn upgrade(&mut self, proto: &str) -> Option<&mut (dyn IoHandler + 'r)> {
self.upgrade.get_mut(proto.as_uncased()).map(|h| h.as_mut()) self.upgrade.get_mut(proto.as_uncased()).map(|h| h.as_mut())
} }
@ -972,7 +969,7 @@ impl<'r> Response<'r> {
/// ///
/// #[rocket::async_trait] /// #[rocket::async_trait]
/// impl IoHandler for EchoHandler { /// impl IoHandler for EchoHandler {
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> { /// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io); /// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?; /// io::copy(&mut reader, &mut writer).await?;
/// Ok(()) /// Ok(())
@ -990,7 +987,7 @@ impl<'r> Response<'r> {
pub fn add_upgrade<N, H>(&mut self, protocol: N, handler: H) pub fn add_upgrade<N, H>(&mut self, protocol: N, handler: H)
where N: Into<Uncased<'r>>, H: IoHandler + 'r where N: Into<Uncased<'r>>, H: IoHandler + 'r
{ {
self.upgrade.insert(protocol.into(), Box::pin(handler)); self.upgrade.insert(protocol.into(), Box::new(handler));
} }
/// Sets the body's maximum chunk size to `size` bytes. /// Sets the body's maximum chunk size to `size` bytes.

View File

@ -1,9 +1,9 @@
use std::borrow::Cow; use std::borrow::Cow;
use tokio::io::AsyncRead; use tokio::io::AsyncRead;
use tokio::time::Duration; use tokio::time::{interval, Duration};
use futures::stream::{self, Stream, StreamExt}; use futures::{stream::{self, Stream}, future::Either};
use futures::future::ready; use tokio_stream::{StreamExt, wrappers::IntervalStream};
use crate::request::Request; use crate::request::Request;
use crate::response::{self, Response, Responder, stream::{ReaderStream, RawLinedEvent}}; use crate::response::{self, Response, Responder, stream::{ReaderStream, RawLinedEvent}};
@ -336,7 +336,7 @@ impl Event {
Some(RawLinedEvent::raw("")), Some(RawLinedEvent::raw("")),
]; ];
stream::iter(events).filter_map(ready) stream::iter(events).filter_map(|x| x)
} }
} }
@ -528,25 +528,19 @@ impl<S: Stream<Item = Event>> EventStream<S> {
self self
} }
fn heartbeat_stream(&self) -> Option<impl Stream<Item = RawLinedEvent>> { fn heartbeat_stream(&self) -> impl Stream<Item = RawLinedEvent> {
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;
self.heartbeat self.heartbeat
.map(|beat| IntervalStream::new(interval(beat))) .map(|beat| IntervalStream::new(interval(beat)))
.map(|stream| stream.map(|_| RawLinedEvent::raw(":"))) .map(|stream| stream.map(|_| RawLinedEvent::raw(":")))
.map_or_else(|| Either::Right(stream::empty()), Either::Left)
} }
fn into_stream(self) -> impl Stream<Item = RawLinedEvent> { fn into_stream(self) -> impl Stream<Item = RawLinedEvent> {
use futures::future::Either; use futures::StreamExt;
use crate::ext::StreamExt;
let heartbeat_stream = self.heartbeat_stream(); let heartbeats = self.heartbeat_stream();
let raw_events = self.stream.map(|e| e.into_stream()).flatten(); let events = StreamExt::map(self.stream, |e| e.into_stream()).flatten();
match heartbeat_stream { crate::util::join(events, heartbeats)
Some(heartbeat) => Either::Left(raw_events.join(heartbeat)),
None => Either::Right(raw_events)
}
} }
fn into_reader(self) -> impl AsyncRead { fn into_reader(self) -> impl AsyncRead {
@ -621,10 +615,11 @@ mod sse_tests {
impl<S: Stream<Item = Event>> EventStream<S> { impl<S: Stream<Item = Event>> EventStream<S> {
fn into_string(self) -> String { fn into_string(self) -> String {
use std::pin::pin;
crate::async_test(async move { crate::async_test(async move {
let mut string = String::new(); let mut string = String::new();
let reader = self.into_reader(); let mut reader = pin!(self.into_reader());
tokio::pin!(reader);
reader.read_to_string(&mut string).await.expect("event stream -> string"); reader.read_to_string(&mut string).await.expect("event stream -> string");
string string
}) })

View File

@ -1,14 +1,14 @@
use std::fmt; use std::fmt;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::net::SocketAddr;
use yansi::Paint; use yansi::Paint;
use either::Either; use either::Either;
use figment::{Figment, Provider}; use figment::{Figment, Provider};
use crate::{Catcher, Config, Route, Shutdown, sentinel, shield::Shield}; use crate::{Catcher, Config, Route, Shutdown, sentinel, shield::Shield};
use crate::listener::{Endpoint, Bindable, DefaultListener};
use crate::router::Router; use crate::router::Router;
use crate::trip_wire::TripWire; use crate::util::TripWire;
use crate::fairing::{Fairing, Fairings}; use crate::fairing::{Fairing, Fairings};
use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting}; use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting};
use crate::phase::{Stateful, StateRef, State}; use crate::phase::{Stateful, StateRef, State};
@ -203,35 +203,31 @@ impl Rocket<Build> {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use rocket::Config; /// use rocket::config::{Config, Ident};
/// # use std::net::Ipv4Addr; /// # use std::net::Ipv4Addr;
/// # use std::path::{Path, PathBuf}; /// # use std::path::{Path, PathBuf};
/// # type Result = std::result::Result<(), rocket::Error>; /// # type Result = std::result::Result<(), rocket::Error>;
/// ///
/// let config = Config { /// let config = Config {
/// port: 7777, /// ident: Ident::try_new("MyServer").expect("valid ident"),
/// address: Ipv4Addr::new(18, 127, 0, 1).into(),
/// temp_dir: "/tmp/config-example".into(), /// temp_dir: "/tmp/config-example".into(),
/// ..Config::debug_default() /// ..Config::debug_default()
/// }; /// };
/// ///
/// # let _: Result = rocket::async_test(async move { /// # let _: Result = rocket::async_test(async move {
/// let rocket = rocket::custom(&config).ignite().await?; /// let rocket = rocket::custom(&config).ignite().await?;
/// assert_eq!(rocket.config().port, 7777); /// assert_eq!(rocket.config().ident.as_str(), Some("MyServer"));
/// assert_eq!(rocket.config().address, Ipv4Addr::new(18, 127, 0, 1));
/// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example")); /// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example"));
/// ///
/// // Create a new figment which modifies _some_ keys the existing figment: /// // Create a new figment which modifies _some_ keys the existing figment:
/// let figment = rocket.figment().clone() /// let figment = rocket.figment().clone()
/// .merge((Config::PORT, 8888)) /// .merge((Config::IDENT, "Example"));
/// .merge((Config::ADDRESS, "171.64.200.10"));
/// ///
/// let rocket = rocket::custom(&config) /// let rocket = rocket::custom(&config)
/// .configure(figment) /// .configure(figment)
/// .ignite().await?; /// .ignite().await?;
/// ///
/// assert_eq!(rocket.config().port, 8888); /// assert_eq!(rocket.config().ident.as_str(), Some("Example"));
/// assert_eq!(rocket.config().address, Ipv4Addr::new(171, 64, 200, 10));
/// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example")); /// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example"));
/// # Ok(()) /// # Ok(())
/// # }); /// # });
@ -664,8 +660,9 @@ impl Rocket<Ignite> {
self.shutdown.clone() self.shutdown.clone()
} }
fn into_orbit(self) -> Rocket<Orbit> { pub(crate) fn into_orbit(self, address: Endpoint) -> Rocket<Orbit> {
Rocket(Orbiting { Rocket(Orbiting {
endpoint: address,
router: self.0.router, router: self.0.router,
fairings: self.0.fairings, fairings: self.0.fairings,
figment: self.0.figment, figment: self.0.figment,
@ -675,28 +672,24 @@ impl Rocket<Ignite> {
}) })
} }
async fn _local_launch(self) -> Rocket<Orbit> { async fn _local_launch(self, addr: Endpoint) -> Rocket<Orbit> {
let rocket = self.into_orbit(); let rocket = self.into_orbit(addr);
rocket.fairings.handle_liftoff(&rocket).await; Rocket::liftoff(&rocket).await;
launch_info!("{}{}", "🚀 ".emoji(), "Rocket has launched locally".primary().bold());
rocket rocket
} }
async fn _launch(self) -> Result<Rocket<Ignite>, Error> { async fn _launch(self) -> Result<Rocket<Ignite>, Error> {
self.into_orbit() let config = self.figment().extract::<DefaultListener>()?;
.default_tcp_http_server(|rkt| Box::pin(async move { either::for_both!(config.base_bindable()?, base => {
rkt.fairings.handle_liftoff(&rkt).await; either::for_both!(config.tls_bindable(base), bindable => {
self._launch_on(bindable).await
})
})
}
let proto = rkt.config.tls_enabled().then(|| "https").unwrap_or("http"); async fn _launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error> {
let socket_addr = SocketAddr::new(rkt.config.address, rkt.config.port); let listener = bindable.bind().await.map_err(|e| ErrorKind::Bind(Box::new(e)))?;
let addr = format!("{}://{}", proto, socket_addr); self.serve(listener).await
launch_info!("{}{} {}",
"🚀 ".emoji(),
"Rocket has launched from".bold().primary().linger(),
addr.underline());
}))
.await
.map(|rocket| rocket.into_ignite())
} }
} }
@ -712,6 +705,21 @@ impl Rocket<Orbit> {
}) })
} }
pub(crate) async fn liftoff<R: Deref<Target = Self>>(rocket: R) {
let rocket = rocket.deref();
rocket.fairings.handle_liftoff(rocket).await;
if !crate::running_within_rocket_async_rt().await {
warn!("Rocket is executing inside of a custom runtime.");
info_!("Rocket's runtime is enabled via `#[rocket::main]` or `#[launch]`.");
info_!("Forced shutdown is disabled. Runtime settings may be suboptimal.");
}
launch_info!("{}{} {}", "🚀 ".emoji(),
"Rocket has launched on".bold().primary().linger(),
rocket.endpoint().underline());
}
/// Returns the finalized, active configuration. This is guaranteed to /// Returns the finalized, active configuration. This is guaranteed to
/// remain stable after [`Rocket::ignite()`], through ignition and into /// remain stable after [`Rocket::ignite()`], through ignition and into
/// orbit. /// orbit.
@ -734,6 +742,10 @@ impl Rocket<Orbit> {
&self.config &self.config
} }
pub fn endpoint(&self) -> &Endpoint {
&self.endpoint
}
/// Returns a handle which can be used to trigger a shutdown and detect a /// Returns a handle which can be used to trigger a shutdown and detect a
/// triggered shutdown. /// triggered shutdown.
/// ///
@ -867,10 +879,10 @@ impl<P: Phase> Rocket<P> {
} }
} }
pub(crate) async fn local_launch(self) -> Result<Rocket<Orbit>, Error> { pub(crate) async fn local_launch(self, l: Endpoint) -> Result<Rocket<Orbit>, Error> {
let rocket = match self.0.into_state() { let rocket = match self.0.into_state() {
State::Build(s) => Rocket::from(s).ignite().await?._local_launch().await, State::Build(s) => Rocket::from(s).ignite().await?._local_launch(l).await,
State::Ignite(s) => Rocket::from(s)._local_launch().await, State::Ignite(s) => Rocket::from(s)._local_launch(l).await,
State::Orbit(s) => Rocket::from(s) State::Orbit(s) => Rocket::from(s)
}; };
@ -928,6 +940,14 @@ impl<P: Phase> Rocket<P> {
State::Orbit(s) => Ok(Rocket::from(s).into_ignite()) State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
} }
} }
pub async fn launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error> {
match self.0.into_state() {
State::Build(s) => Rocket::from(s).ignite().await?._launch_on(bindable).await,
State::Ignite(s) => Rocket::from(s)._launch_on(bindable).await,
State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
}
}
} }
#[doc(hidden)] #[doc(hidden)]

View File

@ -167,7 +167,6 @@ impl<F: Clone + Sync + Send + 'static> Handler for F
} }
} }
// FIXME!
impl<'r, 'o: 'r> Outcome<'o> { impl<'r, 'o: 'r> Outcome<'o> {
/// Return the `Outcome` of response to `req` from `responder`. /// Return the `Outcome` of response to `req` from `responder`.
/// ///

View File

@ -1,540 +1,142 @@
use std::io; use std::io;
use std::pin::pin;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::pin::Pin;
use yansi::Paint; use hyper::service::service_fn;
use tokio::sync::oneshot; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use hyper_util::server::conn::auto::Builder;
use futures::{Future, TryFutureExt, future::{select, Either::*}};
use tokio::time::sleep; use tokio::time::sleep;
use futures::stream::StreamExt;
use futures::future::{FutureExt, Future, BoxFuture};
use crate::{route, Rocket, Orbit, Request, Response, Data, Config}; use crate::{Request, Rocket, Orbit, Data, Ignite};
use crate::form::Form;
use crate::outcome::Outcome;
use crate::error::{Error, ErrorKind};
use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo};
use crate::request::ConnectionMeta; use crate::request::ConnectionMeta;
use crate::data::IoHandler; use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler};
use crate::listener::{Listener, CancellableExt, BouncedExt};
use crate::http::{hyper, uncased, Method, Status, Header}; use crate::error::{Error, ErrorKind};
use crate::http::private::{TcpListener, Listener, Connection, Incoming}; use crate::data::IoStream;
use crate::util::ReaderStream;
// A token returned to force the execution of one method before another. use crate::http::Status;
pub(crate) struct RequestToken;
async fn handle<Fut, T, F>(name: Option<&str>, run: F) -> Option<T>
where F: FnOnce() -> Fut, Fut: Future<Output = T>,
{
use std::panic::AssertUnwindSafe;
macro_rules! panic_info {
($name:expr, $e:expr) => {{
match $name {
Some(name) => error_!("Handler {} panicked.", name.primary()),
None => error_!("A handler panicked.")
};
info_!("This is an application bug.");
info_!("A panic in Rust must be treated as an exceptional event.");
info_!("Panicking is not a suitable error handling mechanism.");
info_!("Unwinding, the result of a panic, is an expensive operation.");
info_!("Panics will degrade application performance.");
info_!("Instead of panicking, return `Option` and/or `Result`.");
info_!("Values of either type can be returned directly from handlers.");
warn_!("A panic is treated as an internal server error.");
$e
}}
}
let run = AssertUnwindSafe(run);
let fut = std::panic::catch_unwind(move || run())
.map_err(|e| panic_info!(name, e))
.ok()?;
AssertUnwindSafe(fut)
.catch_unwind()
.await
.map_err(|e| panic_info!(name, e))
.ok()
}
// This function tries to hide all of the Hyper-ness from Rocket. It essentially
// converts Hyper types into Rocket types, then calls the `dispatch` function,
// which knows nothing about Hyper. Because responding depends on the
// `HyperResponse` type, this function does the actual response processing.
async fn hyper_service_fn(
rocket: Arc<Rocket<Orbit>>,
conn: ConnectionMeta,
mut hyp_req: hyper::Request<hyper::Body>,
) -> Result<hyper::Response<hyper::Body>, io::Error> {
// This future must return a hyper::Response, but the response body might
// borrow from the request. Instead, write the body in another future that
// sends the response metadata (and a body channel) prior.
let (tx, rx) = oneshot::channel();
#[cfg(not(broken_fmt))]
debug!("received request: {:#?}", hyp_req);
tokio::spawn(async move {
// We move the request next, so get the upgrade future now.
let pending_upgrade = hyper::upgrade::on(&mut hyp_req);
// Convert a Hyper request into a Rocket request.
let (h_parts, mut h_body) = hyp_req.into_parts();
match Request::from_hyp(&rocket, &h_parts, Some(conn)) {
Ok(mut req) => {
// Convert into Rocket `Data`, dispatch request, write response.
let mut data = Data::from(&mut h_body);
let token = rocket.preprocess_request(&mut req, &mut data).await;
let mut response = rocket.dispatch(token, &req, data).await;
let upgrade = response.take_upgrade(req.headers().get("upgrade"));
if let Ok(Some((proto, handler))) = upgrade {
rocket.handle_upgrade(response, proto, handler, pending_upgrade, tx).await;
} else {
if upgrade.is_err() {
warn_!("Request wants upgrade but no I/O handler matched.");
info_!("Request is not being upgraded.");
}
rocket.send_response(response, tx).await;
}
},
Err(e) => {
warn!("Bad incoming HTTP request.");
e.errors.iter().for_each(|e| warn_!("Error: {}.", e));
warn_!("Dispatching salvaged request to catcher: {}.", e.request);
let response = rocket.handle_error(Status::BadRequest, &e.request).await;
rocket.send_response(response, tx).await;
}
}
});
// Receive the response written to `tx` by the task above.
rx.await.map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))
}
impl Rocket<Orbit> { impl Rocket<Orbit> {
/// Wrapper around `_send_response` to log a success or error. async fn service(
#[inline] self: Arc<Self>,
async fn send_response( mut req: hyper::Request<hyper::body::Incoming>,
&self, connection: ConnectionMeta,
response: Response<'_>, ) -> Result<hyper::Response<ReaderStream<ErasedResponse>>, http::Error> {
tx: oneshot::Sender<hyper::Response<hyper::Body>>, let upgrade = hyper::upgrade::on(&mut req);
) { let (parts, incoming) = req.into_parts();
let remote_hungup = |e: &io::Error| match e.kind() { let request = ErasedRequest::new(self, parts, |rocket, parts| {
| io::ErrorKind::BrokenPipe Request::from_hyp(rocket, parts, connection).unwrap_or_else(|e| e)
| io::ErrorKind::ConnectionReset });
| io::ErrorKind::ConnectionAborted => true,
_ => false,
};
match self._send_response(response, tx).await { let mut response = request.into_response(
Ok(()) => info_!("{}", "Response succeeded.".green()), incoming,
Err(e) if remote_hungup(&e) => warn_!("Remote left: {}.", e), |incoming| Data::from(incoming),
Err(e) => warn_!("Failed to write response: {}.", e), |rocket, request, data| Box::pin(rocket.preprocess(request, data)),
} |token, rocket, request, data| Box::pin(async move {
} if !request.errors.is_empty() {
return rocket.dispatch_error(Status::BadRequest, request).await;
/// Attempts to create a hyper response from `response` and send it to `tx`.
#[inline]
async fn _send_response(
&self,
mut response: Response<'_>,
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
) -> io::Result<()> {
let mut hyp_res = hyper::Response::builder();
hyp_res = hyp_res.status(response.status().code);
for header in response.headers().iter() {
let name = header.name.as_str();
let value = header.value.as_bytes();
hyp_res = hyp_res.header(name, value);
}
let body = response.body_mut();
if let Some(n) = body.size().await {
hyp_res = hyp_res.header(hyper::header::CONTENT_LENGTH, n);
}
let (mut sender, hyp_body) = hyper::Body::channel();
let hyp_response = hyp_res.body(hyp_body)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
#[cfg(not(broken_fmt))]
debug!("sending response: {:#?}", hyp_response);
tx.send(hyp_response).map_err(|_| {
let msg = "client disconnect before response started";
io::Error::new(io::ErrorKind::BrokenPipe, msg)
})?;
let max_chunk_size = body.max_chunk_size();
let mut stream = body.into_bytes_stream(max_chunk_size);
while let Some(next) = stream.next().await {
sender.send_data(next?).await
.map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;
}
Ok(())
}
async fn handle_upgrade<'r>(
&self,
mut response: Response<'r>,
proto: uncased::Uncased<'r>,
io_handler: Pin<Box<dyn IoHandler + 'r>>,
pending_upgrade: hyper::upgrade::OnUpgrade,
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
) {
info_!("Upgrading connection to {}.", Paint::white(&proto).bold());
response.set_status(Status::SwitchingProtocols);
response.set_raw_header("Connection", "Upgrade");
response.set_raw_header("Upgrade", proto.clone().into_cow());
self.send_response(response, tx).await;
match pending_upgrade.await {
Ok(io_stream) => {
info_!("Upgrade successful.");
if let Err(e) = io_handler.io(io_stream.into()).await {
if e.kind() == io::ErrorKind::BrokenPipe {
warn!("Upgraded {} I/O handler was closed.", proto);
} else {
error!("Upgraded {} I/O handler failed: {}", proto, e);
}
}
},
Err(e) => {
warn!("Response indicated upgrade, but upgrade failed.");
warn_!("Upgrade error: {}", e);
}
}
}
/// Preprocess the request for Rocket things. Currently, this means:
///
/// * Rewriting the method in the request if _method form field exists.
/// * Run the request fairings.
///
/// Keep this in-sync with derive_form when preprocessing form fields.
pub(crate) async fn preprocess_request(
&self,
req: &mut Request<'_>,
data: &mut Data<'_>
) -> RequestToken {
// Check if this is a form and if the form contains the special _method
// field which we use to reinterpret the request's method.
let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
let peek_buffer = data.peek(max_len).await;
let is_form = req.content_type().map_or(false, |ct| ct.is_form());
if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len {
let method = std::str::from_utf8(peek_buffer).ok()
.and_then(|raw_form| Form::values(raw_form).next())
.filter(|field| field.name == "_method")
.and_then(|field| field.value.parse().ok());
if let Some(method) = method {
req._set_method(method);
}
}
// Run request fairings.
self.fairings.handle_request(req, data).await;
RequestToken
}
#[inline]
pub(crate) async fn dispatch<'s, 'r: 's>(
&'s self,
_token: RequestToken,
request: &'r Request<'s>,
data: Data<'r>
) -> Response<'r> {
info!("{}:", request);
// Remember if the request is `HEAD` for later body stripping.
let was_head_request = request.method() == Method::Head;
// Route the request and run the user's handlers.
let mut response = self.route_and_process(request, data).await;
// Add a default 'Server' header if it isn't already there.
// TODO: If removing Hyper, write out `Date` header too.
if let Some(ident) = request.rocket().config.ident.as_str() {
if !response.headers().contains("Server") {
response.set_header(Header::new("Server", ident));
}
}
// Run the response fairings.
self.fairings.handle_response(request, &mut response).await;
// Strip the body if this is a `HEAD` request.
if was_head_request {
response.strip_body();
} }
let mut response = rocket.dispatch(token, request, data).await;
response.body_mut().size().await;
response response
})
).await;
let io_handler = response.to_io_handler(Rocket::extract_io_handler);
if let Some(handler) = io_handler {
let upgrade = upgrade.map_ok(IoStream::from).map_err(io::Error::other);
tokio::task::spawn(io_handler_task(upgrade, handler));
} }
async fn route_and_process<'s, 'r: 's>( let mut builder = hyper::Response::builder();
&'s self, builder = builder.status(response.inner().status().code);
request: &'r Request<'s>, for header in response.inner().headers().iter() {
data: Data<'r> builder = builder.header(header.name().as_str(), header.value());
) -> Response<'r> { }
let mut response = match self.route(request, data).await {
Outcome::Success(response) => response,
Outcome::Forward((data, _)) if request.method() == Method::Head => {
info_!("Autohandling {} request.", "HEAD".primary().bold());
// Dispatch the request again with Method `GET`. if let Some(size) = response.inner().body().preset_size() {
request._set_method(Method::Get); builder = builder.header("Content-Length", size);
match self.route(request, data).await {
Outcome::Success(response) => response,
Outcome::Error(status) => self.handle_error(status, request).await,
Outcome::Forward((_, status)) => self.handle_error(status, request).await,
} }
let chunk_size = response.inner().body().max_chunk_size();
builder.body(ReaderStream::with_capacity(response, chunk_size))
} }
Outcome::Forward((_, status)) => self.handle_error(status, request).await, }
Outcome::Error(status) => self.handle_error(status, request).await,
async fn io_handler_task<S>(stream: S, mut handler: ErasedIoHandler)
where S: Future<Output = io::Result<IoStream>>
{
let stream = match stream.await {
Ok(stream) => stream,
Err(e) => return warn_!("Upgrade failed: {e}"),
}; };
// Set the cookies. Note that error responses will only include cookies info_!("Upgrade succeeded.");
// set by the error handler. See `handle_error` for more. if let Err(e) = handler.take().io(stream).await {
let delta_jar = request.cookies().take_delta_jar(); match e.kind() {
for cookie in delta_jar.delta() { io::ErrorKind::BrokenPipe => warn!("Upgrade I/O handler was closed."),
response.adjoin_header(cookie); e => error!("Upgrade I/O handler failed: {e}"),
}
response
}
/// Tries to find a `Responder` for a given `request`. It does this by
/// routing the request and calling the handler for each matching route
/// until one of the handlers returns success or error, or there are no
/// additional routes to try (forward). The corresponding outcome for each
/// condition is returned.
#[inline]
async fn route<'s, 'r: 's>(
&'s self,
request: &'r Request<'s>,
mut data: Data<'r>,
) -> route::Outcome<'r> {
// Go through all matching routes until we fail or succeed or run out of
// routes to try, in which case we forward with the last status.
let mut status = Status::NotFound;
for route in self.router.route(request) {
// Retrieve and set the requests parameters.
info_!("Matched: {}", route);
request.set_route(route);
let name = route.name.as_deref();
let outcome = handle(name, || route.handler.handle(request, data)).await
.unwrap_or(Outcome::Error(Status::InternalServerError));
// Check if the request processing completed (Some) or if the
// request needs to be forwarded. If it does, continue the loop
// (None) to try again.
info_!("{}", outcome.log_display());
match outcome {
o@Outcome::Success(_) | o@Outcome::Error(_) => return o,
Outcome::Forward(forwarded) => (data, status) = forwarded,
} }
} }
}
error_!("No matching routes for {}.", request); impl Rocket<Ignite> {
Outcome::Forward((data, status)) pub(crate) async fn serve<L>(self, listener: L) -> Result<Self, crate::Error>
} where L: Listener + 'static
/// Invokes the handler with `req` for catcher with status `status`.
///
/// In order of preference, invoked handler is:
/// * the user's registered handler for `status`
/// * the user's registered `default` handler
/// * Rocket's default handler for `status`
///
/// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))`
/// if the handler ran to completion but failed. Returns `Ok(None)` if the
/// handler panicked while executing.
async fn invoke_catcher<'s, 'r: 's>(
&'s self,
status: Status,
req: &'r Request<'s>
) -> Result<Response<'r>, Option<Status>> {
// For now, we reset the delta state to prevent any modifications
// from earlier, unsuccessful paths from being reflected in error
// response. We may wish to relax this in the future.
req.cookies().reset_delta();
if let Some(catcher) = self.router.catch(status, req) {
warn_!("Responding with registered {} catcher.", catcher);
let name = catcher.name.as_deref();
handle(name, || catcher.handler.handle(status, req)).await
.map(|result| result.map_err(Some))
.unwrap_or_else(|| Err(None))
} else {
let code = status.code.blue().bold();
warn_!("No {} catcher registered. Using Rocket default.", code);
Ok(crate::catcher::default_handler(status, req))
}
}
// Invokes the catcher for `status`. Returns the response on success.
//
// On catcher error, the 500 error catcher is attempted. If _that_ errors,
// the (infallible) default 500 error cather is used.
pub(crate) async fn handle_error<'s, 'r: 's>(
&'s self,
mut status: Status,
req: &'r Request<'s>
) -> Response<'r> {
// Dispatch to the `status` catcher.
if let Ok(r) = self.invoke_catcher(status, req).await {
return r;
}
// If it fails and it's not a 500, try the 500 catcher.
if status != Status::InternalServerError {
error_!("Catcher failed. Attempting 500 error catcher.");
status = Status::InternalServerError;
if let Ok(r) = self.invoke_catcher(status, req).await {
return r;
}
}
// If it failed again or if it was already a 500, use Rocket's default.
error_!("{} catcher failed. Using Rocket default 500.", status.code);
crate::catcher::default_handler(Status::InternalServerError, req)
}
pub(crate) async fn default_tcp_http_server<C>(mut self, ready: C) -> Result<Self, Error>
where C: for<'a> Fn(&'a Self) -> BoxFuture<'a, ()>
{ {
use std::net::ToSocketAddrs; let mut builder = Builder::new(TokioExecutor::new());
let keep_alive = Duration::from_secs(self.config.keep_alive.into());
builder.http1()
.half_close(true)
.timer(TokioTimer::new())
.keep_alive(keep_alive > Duration::ZERO)
.preserve_header_case(true)
.header_read_timeout(Duration::from_secs(15));
// Determine the address we're going to serve on. #[cfg(feature = "http2")] {
let addr = format!("{}:{}", self.config.address, self.config.port); builder.http2().timer(TokioTimer::new());
let mut addr = addr.to_socket_addrs() if keep_alive > Duration::ZERO {
.map(|mut addrs| addrs.next().expect(">= 1 socket addr")) builder.http2()
.map_err(|e| Error::new(ErrorKind::Io(e)))?; .timer(TokioTimer::new())
.keep_alive_interval(keep_alive / 4)
#[cfg(feature = "tls")] .keep_alive_timeout(keep_alive);
if self.config.tls_enabled() {
if let Some(ref config) = self.config.tls {
use crate::http::tls::TlsListener;
let conf = config.to_native_config().map_err(ErrorKind::Io)?;
let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::TlsBind)?;
addr = l.local_addr().unwrap_or(addr);
self.config.address = addr.ip();
self.config.port = addr.port();
ready(&mut self).await;
return self.http_server(l).await;
} }
} }
let l = TcpListener::bind(addr).await.map_err(ErrorKind::Bind)?; let listener = listener.bounced().cancellable(self.shutdown(), &self.config.shutdown);
addr = l.local_addr().unwrap_or(addr); let rocket = Arc::new(self.into_orbit(listener.socket_addr()?));
self.config.address = addr.ip(); let _ = tokio::spawn(Rocket::liftoff(rocket.clone())).await;
self.config.port = addr.port();
ready(&mut self).await; let (server, listener) = (Arc::new(builder), Arc::new(listener));
self.http_server(l).await while let Some(accept) = listener.accept_next().await {
let (listener, rocket, server) = (listener.clone(), rocket.clone(), server.clone());
tokio::spawn({
let result = async move {
let conn = TokioIo::new(listener.connect(accept).await?);
let meta = ConnectionMeta::from(conn.inner());
let service = service_fn(|req| rocket.clone().service(req, meta.clone()));
let serve = pin!(server.serve_connection_with_upgrades(conn, service));
match select(serve, rocket.shutdown()).await {
Left((result, _)) => result,
Right((_, mut conn)) => {
conn.as_mut().graceful_shutdown();
conn.await
} }
// TODO.async: Solidify the Listener APIs and make this function public
pub(crate) async fn http_server<L>(self, listener: L) -> Result<Self, Error>
where L: Listener + Send, <L as Listener>::Connection: Send + Unpin + 'static
{
// Emit a warning if we're not running inside of Rocket's async runtime.
if self.config.profile == Config::DEBUG_PROFILE {
tokio::task::spawn_blocking(|| {
let this = std::thread::current();
if !this.name().map_or(false, |s| s.starts_with("rocket-worker")) {
warn!("Rocket is executing inside of a custom runtime.");
info_!("Rocket's runtime is enabled via `#[rocket::main]` or `#[launch]`.");
info_!("Forced shutdown is disabled. Runtime settings may be suboptimal.");
}
});
}
// Set up cancellable I/O from the given listener. Shutdown occurs when
// `Shutdown` (`TripWire`) resolves. This can occur directly through a
// notification or indirectly through an external signal which, when
// received, results in triggering the notify.
let shutdown = self.shutdown();
let sig_stream = self.config.shutdown.signal_stream();
let grace = self.config.shutdown.grace as u64;
let mercy = self.config.shutdown.mercy as u64;
// Start a task that listens for external signals and notifies shutdown.
if let Some(mut stream) = sig_stream {
let shutdown = shutdown.clone();
tokio::spawn(async move {
while let Some(sig) = stream.next().await {
if shutdown.0.tripped() {
warn!("Received {}. Shutdown already in progress.", sig);
} else {
warn!("Received {}. Requesting shutdown.", sig);
}
shutdown.0.trip();
}
});
}
// Save the keep-alive value for later use; we're about to move `self`.
let keep_alive = self.config.keep_alive;
// Create the Hyper `Service`.
let rocket = Arc::new(self);
let service_fn = |conn: &CancellableIo<_, L::Connection>| {
let rocket = rocket.clone();
let connection = ConnectionMeta {
remote: conn.peer_address(),
client_certificates: conn.peer_certificates(),
};
async move {
Ok::<_, std::convert::Infallible>(hyper::service::service_fn(move |req| {
hyper_service_fn(rocket.clone(), connection.clone(), req)
}))
} }
}; };
// NOTE: `hyper` uses `tokio::spawn()` as the default executor. result.inspect_err(crate::error::log_server_error)
let listener = CancellableListener::new(shutdown.clone(), listener, grace, mercy);
let builder = hyper::server::Server::builder(Incoming::new(listener).nodelay(true));
#[cfg(feature = "http2")]
let builder = builder.http2_keep_alive_interval(match keep_alive {
0 => None,
n => Some(Duration::from_secs(n as u64))
}); });
}
let server = builder // Rocket wraps all connections in a `CancellableIo` struct, an internal
.http1_keepalive(keep_alive != 0) // structure that gracefully closes I/O when it receives a signal. That
.http1_preserve_header_case(true) // signal is the `shutdown` future. When the future resolves,
.serve(hyper::service::make_service_fn(service_fn)) // `CancellableIo` begins to terminate in grace, mercy, and finally
.with_graceful_shutdown(shutdown.clone()); // force close phases. Since all connections are wrapped in
// This deserves some explanation.
//
// This is largely to deal with Hyper's dreadful and largely nonexistent
// handling of shutdown, in general, nevermind graceful.
//
// When Hyper receives a "graceful shutdown" request, it stops accepting
// new requests. That's it. It continues to process existing requests
// and outgoing responses forever and never cancels them. As a result,
// Rocket must take it upon itself to cancel any existing I/O.
//
// To do so, Rocket wraps all connections in a `CancellableIo` struct,
// an internal structure that gracefully closes I/O when it receives a
// signal. That signal is the `shutdown` future. When the future
// resolves, `CancellableIo` begins to terminate in grace, mercy, and
// finally force close phases. Since all connections are wrapped in
// `CancellableIo`, this eventually ends all I/O. // `CancellableIo`, this eventually ends all I/O.
// //
// At that point, unless a user spawned an infinite, stand-alone task // At that point, unless a user spawned an infinite, stand-alone task
@ -543,69 +145,35 @@ impl Rocket<Orbit> {
// we can return the owned instance of `Rocket`. // we can return the owned instance of `Rocket`.
// //
// Unfortunately, the Hyper `server` future resolves as soon as it has // Unfortunately, the Hyper `server` future resolves as soon as it has
// finishes processing requests without respect for ongoing responses. // finished processing requests without respect for ongoing responses.
// That is, `server` resolves even when there are running tasks that are // That is, `server` resolves even when there are running tasks that are
// generating a response. So, `server` resolving implies little to // generating a response. So, `server` resolving implies little to
// nothing about the state of connections. As a result, we depend on the // nothing about the state of connections. As a result, we depend on the
// timing of grace + mercy + some buffer to determine when all // timing of grace + mercy + some buffer to determine when all
// connections should be closed, thus all tasks should be complete, thus // connections should be closed, thus all tasks should be complete, thus
// all references to `Arc<Rocket>` should be dropped and we can get a // all references to `Arc<Rocket>` should be dropped and we can get back
// unique reference. // a unique reference.
tokio::pin!(server); info!("Shutting down. Waiting for shutdown fairings and pending I/O...");
tokio::select! { tokio::spawn({
biased; let rocket = rocket.clone();
async move { rocket.fairings.handle_shutdown(&*rocket).await }
});
_ = shutdown => { let config = &rocket.config.shutdown;
// Run shutdown fairings. We compute `sleep()` for grace periods let wait = Duration::from_micros(250);
// beforehand to ensure we don't add shutdown fairing completion for period in [wait, config.grace(), wait, config.mercy(), wait * 4] {
// time, which is arbitrary, to these periods. if Arc::strong_count(&rocket) == 1 { break }
info!("Shutdown requested. Waiting for pending I/O..."); sleep(period).await;
let grace_timer = sleep(Duration::from_secs(grace));
let mercy_timer = sleep(Duration::from_secs(grace + mercy));
let shutdown_timer = sleep(Duration::from_secs(grace + mercy + 1));
rocket.fairings.handle_shutdown(&*rocket).await;
tokio::pin!(grace_timer, mercy_timer, shutdown_timer);
tokio::select! {
biased;
result = &mut server => {
if let Err(e) = result {
warn!("Server failed while shutting down: {}", e);
return Err(Error::shutdown(rocket.clone(), e));
} }
if Arc::strong_count(&rocket) != 1 { grace_timer.await; }
if Arc::strong_count(&rocket) != 1 { mercy_timer.await; }
if Arc::strong_count(&rocket) != 1 { shutdown_timer.await; }
match Arc::try_unwrap(rocket) { match Arc::try_unwrap(rocket) {
Ok(rocket) => { Ok(rocket) => {
info!("Graceful shutdown completed successfully."); info!("Graceful shutdown completed successfully.");
Ok(rocket) Ok(rocket.into_ignite())
} }
Err(rocket) => { Err(rocket) => {
warn!("Shutdown failed: outstanding background I/O."); warn!("Shutdown failed: outstanding background I/O.");
Err(Error::shutdown(rocket, None)) Err(Error::new(ErrorKind::Shutdown(rocket)))
}
}
}
_ = &mut shutdown_timer => {
warn!("Shutdown failed: server executing after timeouts.");
return Err(Error::shutdown(rocket.clone(), None));
},
}
}
result = &mut server => {
match result {
Ok(()) => {
info!("Server shutdown nominally.");
Ok(Arc::try_unwrap(rocket).map_err(|r| Error::shutdown(r, None))?)
}
Err(e) => {
info!("Server failed prior to shutdown: {}:", e);
Err(Error::shutdown(rocket.clone(), e))
}
}
} }
} }
} }

View File

@ -198,7 +198,7 @@ impl Fairing for Shield {
} }
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) { async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
let force_hsts = rocket.config().tls_enabled() let force_hsts = rocket.endpoint().is_tls()
&& rocket.figment().profile() != Config::DEBUG_PROFILE && rocket.figment().profile() != Config::DEBUG_PROFILE
&& !self.is_enabled::<Hsts>(); && !self.is_enabled::<Hsts>();

View File

@ -5,7 +5,7 @@ use std::pin::Pin;
use futures::FutureExt; use futures::FutureExt;
use crate::request::{FromRequest, Outcome, Request}; use crate::request::{FromRequest, Outcome, Request};
use crate::trip_wire::TripWire; use crate::util::TripWire;
/// A request guard and future for graceful shutdown. /// A request guard and future for graceful shutdown.
/// ///

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,7 @@ pub enum KeyError {
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
Io(std::io::Error), Io(std::io::Error),
Bind(Box<dyn std::error::Error + Send + 'static>),
Tls(rustls::Error), Tls(rustls::Error),
Mtls(rustls::server::VerifierBuilderError), Mtls(rustls::server::VerifierBuilderError),
CertChain(std::io::Error), CertChain(std::io::Error),
@ -29,6 +30,7 @@ impl std::fmt::Display for Error {
CertChain(e) => write!(f, "failed to process certificate chain: {e}"), CertChain(e) => write!(f, "failed to process certificate chain: {e}"),
PrivKey(e) => write!(f, "failed to process private key: {e}"), PrivKey(e) => write!(f, "failed to process private key: {e}"),
CertAuth(e) => write!(f, "failed to process certificate authority: {e}"), CertAuth(e) => write!(f, "failed to process certificate authority: {e}"),
Bind(e) => write!(f, "failed to bind to network interface: {e}"),
} }
} }
} }
@ -66,6 +68,7 @@ impl std::error::Error for Error {
Error::CertChain(e) => Some(e), Error::CertChain(e) => Some(e),
Error::PrivKey(e) => Some(e), Error::PrivKey(e) => Some(e),
Error::CertAuth(e) => Some(e), Error::CertAuth(e) => Some(e),
Error::Bind(e) => Some(&**e),
} }
} }
} }

7
core/lib/src/tls/mod.rs Normal file
View File

@ -0,0 +1,7 @@
mod error;
pub(crate) mod config;
pub(crate) mod util;
pub use error::Result;
pub use config::{TlsConfig, CipherSuite};
pub use error::Error;

View File

@ -0,0 +1,52 @@
use std::io;
use std::task::{Poll, Context};
use std::pin::Pin;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, ReadBuf};
pin_project! {
/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
#[must_use = "streams do nothing unless polled"]
pub struct Chain<T, U> {
#[pin]
first: Option<T>,
#[pin]
second: U,
}
}
impl<T, U> Chain<T, U> {
pub(crate) fn new(first: T, second: U) -> Self {
Self { first: Some(first), second }
}
}
impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
/// Gets references to the underlying readers in this `Chain`.
pub fn get_ref(&self) -> (Option<&T>, &U) {
(self.first.as_ref(), &self.second)
}
}
impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let me = self.as_mut().project();
if let Some(first) = me.first.as_pin_mut() {
let init_rem = buf.remaining();
futures::ready!(first.poll_read(cx, buf))?;
if buf.remaining() == init_rem {
self.as_mut().project().first.set(None);
} else {
return Poll::Ready(Ok(()));
}
}
let me = self.as_mut().project();
me.second.poll_read(cx, buf)
}
}

77
core/lib/src/util/join.rs Normal file
View File

@ -0,0 +1,77 @@
use std::pin::Pin;
use std::task::{Poll, Context};
use pin_project_lite::pin_project;
use futures::stream::Stream;
use futures::ready;
/// Join two streams, `a` and `b`, into a new `Join` stream that returns items
/// from both streams, biased to `a`, until `a` finishes. The joined stream
/// completes when `a` completes, irrespective of `b`. If `b` stops producing
/// values, then the joined stream acts exactly like a fused `a`.
///
/// Values are biased to those of `a`: if `a` provides a value, it is always
/// emitted before a value provided by `b`. In other words, values from `b` are
/// emitted when and only when `a` is not producing a value.
pub fn join<A: Stream, B: Stream>(a: A, b: B) -> Join<A, B> {
Join { a, b: Some(b), done: false, }
}
pin_project! {
/// Stream returned by [`join`].
pub struct Join<T, U> {
#[pin]
a: T,
#[pin]
b: Option<U>,
// Set when `a` returns `None`.
done: bool,
}
}
impl<T, U> Stream for Join<T, U>
where T: Stream,
U: Stream<Item = T::Item>,
{
type Item = T::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
if self.done {
return Poll::Ready(None);
}
let me = self.as_mut().project();
match me.a.poll_next(cx) {
Poll::Ready(opt) => {
*me.done = opt.is_none();
Poll::Ready(opt)
},
Poll::Pending => match me.b.as_pin_mut() {
None => Poll::Pending,
Some(b) => match ready!(b.poll_next(cx)) {
Some(value) => Poll::Ready(Some(value)),
None => {
self.as_mut().project().b.set(None);
Poll::Pending
}
}
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let (left_low, left_high) = self.a.size_hint();
let (right_low, right_high) = self.b.as_ref()
.map(|b| b.size_hint())
.unwrap_or_default();
let low = left_low.saturating_add(right_low);
let high = match (left_high, right_high) {
(Some(h1), Some(h2)) => h1.checked_add(h2),
_ => None,
};
(low, high)
}
}

12
core/lib/src/util/mod.rs Normal file
View File

@ -0,0 +1,12 @@
mod chain;
mod tripwire;
mod reader_stream;
mod join;
#[cfg(unix)]
pub mod unix;
pub use chain::Chain;
pub use tripwire::TripWire;
pub use reader_stream::ReaderStream;
pub use join::join;

View File

@ -0,0 +1,124 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use futures::stream::Stream;
use pin_project_lite::pin_project;
use tokio::io::AsyncRead;
pin_project! {
/// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks.
///
/// This stream is fused. It performs the inverse operation of
/// [`StreamReader`].
///
/// # Example
///
/// ```
/// # #[tokio::main]
/// # async fn main() -> std::io::Result<()> {
/// use tokio_stream::StreamExt;
/// use tokio_util::io::ReaderStream;
///
/// // Create a stream of data.
/// let data = b"hello, world!";
/// let mut stream = ReaderStream::new(&data[..]);
///
/// // Read all of the chunks into a vector.
/// let mut stream_contents = Vec::new();
/// while let Some(chunk) = stream.next().await {
/// stream_contents.extend_from_slice(&chunk?);
/// }
///
/// // Once the chunks are concatenated, we should have the
/// // original data.
/// assert_eq!(stream_contents, data);
/// # Ok(())
/// # }
/// ```
///
/// [`AsyncRead`]: tokio::io::AsyncRead
/// [`StreamReader`]: crate::io::StreamReader
/// [`Stream`]: futures_core::Stream
#[derive(Debug)]
pub struct ReaderStream<R> {
// Reader itself.
//
// This value is `None` if the stream has terminated.
#[pin]
reader: R,
// Working buffer, used to optimize allocations.
buf: BytesMut,
capacity: usize,
done: bool,
}
}
impl<R: AsyncRead> ReaderStream<R> {
/// Convert an [`AsyncRead`] into a [`Stream`] with item type
/// `Result<Bytes, std::io::Error>`,
/// with a specific read buffer initial capacity.
///
/// [`AsyncRead`]: tokio::io::AsyncRead
/// [`Stream`]: futures_core::Stream
pub fn with_capacity(reader: R, capacity: usize) -> Self {
ReaderStream {
reader: reader,
buf: BytesMut::with_capacity(capacity),
capacity,
done: false,
}
}
}
impl<R: AsyncRead> Stream for ReaderStream<R> {
type Item = std::io::Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use tokio_util::io::poll_read_buf;
let mut this = self.as_mut().project();
if *this.done {
return Poll::Ready(None);
}
if this.buf.capacity() == 0 {
this.buf.reserve(*this.capacity);
}
let reader = this.reader;
match poll_read_buf(reader, cx, &mut this.buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => {
*this.done = true;
Poll::Ready(Some(Err(err)))
}
Poll::Ready(Ok(0)) => {
*this.done = true;
Poll::Ready(None)
}
Poll::Ready(Ok(_)) => {
let chunk = this.buf.split();
Poll::Ready(Some(Ok(chunk.freeze())))
}
}
}
// fn size_hint(&self) -> (usize, Option<usize>) {
// self.reader.size_hint()
// }
}
impl<R: AsyncRead> hyper::body::Body for ReaderStream<R> {
type Data = bytes::Bytes;
type Error = std::io::Error;
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
self.poll_next(cx).map_ok(hyper::body::Frame::data)
}
}

25
core/lib/src/util/unix.rs Normal file
View File

@ -0,0 +1,25 @@
use std::io;
use std::os::fd::AsRawFd;
pub fn lock_exlusive_nonblocking<T: AsRawFd>(file: &T) -> io::Result<()> {
let raw_fd = file.as_raw_fd();
let res = unsafe {
libc::flock(raw_fd, libc::LOCK_EX | libc::LOCK_NB)
};
match res {
0 => Ok(()),
_ => Err(io::Error::last_os_error()),
}
}
pub fn unlock_nonblocking<T: AsRawFd>(file: &T) -> io::Result<()> {
let res = unsafe {
libc::flock(file.as_raw_fd(), libc::LOCK_UN | libc::LOCK_NB)
};
match res {
0 => Ok(()),
_ => Err(io::Error::last_os_error()),
}
}

View File

@ -1,8 +1,9 @@
#![cfg(feature = "tls")] #![cfg(feature = "tls")]
use rocket::fs::relative; use rocket::fs::relative;
use rocket::config::{Config, TlsConfig, CipherSuite};
use rocket::local::asynchronous::Client; use rocket::local::asynchronous::Client;
use rocket::tls::{TlsConfig, CipherSuite};
use rocket::figment::providers::Serialized;
#[rocket::async_test] #[rocket::async_test]
async fn can_launch_tls() { async fn can_launch_tls() {
@ -15,9 +16,8 @@ async fn can_launch_tls() {
CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
]); ]);
let rocket = rocket::custom(Config { tls: Some(tls), ..Config::debug_default() }); let config = rocket::Config::figment().merge(Serialized::defaults(tls));
let client = Client::debug(rocket).await.unwrap(); let client = Client::debug(rocket::custom(config)).await.unwrap();
client.rocket().shutdown().notify(); client.rocket().shutdown().notify();
client.rocket().shutdown().await; client.rocket().shutdown().await;
} }

View File

@ -1,3 +1,5 @@
use std::net::{SocketAddr, Ipv4Addr};
use rocket::config::Config; use rocket::config::Config;
use rocket::fairing::AdHoc; use rocket::fairing::AdHoc;
use rocket::futures::channel::oneshot; use rocket::futures::channel::oneshot;
@ -5,13 +7,13 @@ use rocket::futures::channel::oneshot;
#[rocket::async_test] #[rocket::async_test]
async fn on_ignite_fairing_can_inspect_port() { async fn on_ignite_fairing_can_inspect_port() {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let rocket = rocket::custom(Config { port: 0, ..Config::debug_default() }) let rocket = rocket::custom(Config::debug_default())
.attach(AdHoc::on_liftoff("Send Port -> Channel", move |rocket| { .attach(AdHoc::on_liftoff("Send Port -> Channel", move |rocket| {
Box::pin(async move { Box::pin(async move {
tx.send(rocket.config().port).unwrap(); tx.send(rocket.endpoint().tcp().unwrap().port()).unwrap();
}) })
})); }));
rocket::tokio::spawn(rocket.launch()); rocket::tokio::spawn(rocket.launch_on(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))));
assert_ne!(rx.await.unwrap(), 0); assert_ne!(rx.await.unwrap(), 0);
} }

View File

@ -155,7 +155,7 @@ fn inner_sentinels_detected() {
impl<'r, 'o: 'r> response::Responder<'r, 'o> for ResponderSentinel { impl<'r, 'o: 'r> response::Responder<'r, 'o> for ResponderSentinel {
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
todo!() unimplemented!()
} }
} }

View File

@ -8,19 +8,14 @@ macro_rules! relative {
#[test] #[test]
fn tls_config_from_source() { fn tls_config_from_source() {
use rocket::config::{Config, TlsConfig}; use rocket::tls::TlsConfig;
use rocket::figment::Figment; use rocket::figment::{Figment, providers::Serialized};
let cert_path = relative!("examples/tls/private/cert.pem"); let cert_path = relative!("examples/tls/private/cert.pem");
let key_path = relative!("examples/tls/private/key.pem"); let key_path = relative!("examples/tls/private/key.pem");
let config = TlsConfig::from_paths(cert_path, key_path);
let rocket_config = Config { let tls: TlsConfig = Figment::from(Serialized::globals(config)).extract().unwrap();
tls: Some(TlsConfig::from_paths(cert_path, key_path)),
..Default::default()
};
let config: Config = Figment::from(rocket_config).extract().unwrap();
let tls = config.tls.expect("have TLS config");
assert_eq!(tls.certs().unwrap_left(), cert_path); assert_eq!(tls.certs().unwrap_left(), cert_path);
assert_eq!(tls.key().unwrap_left(), key_path); assert_eq!(tls.key().unwrap_left(), key_path);
} }

View File

@ -6,15 +6,11 @@ async fn test_config(profile: &str) {
let config = rocket.config(); let config = rocket.config();
match &*profile { match &*profile {
"debug" => { "debug" => {
assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST);
assert_eq!(config.port, 8000);
assert_eq!(config.workers, 1); assert_eq!(config.workers, 1);
assert_eq!(config.keep_alive, 0); assert_eq!(config.keep_alive, 0);
assert_eq!(config.log_level, LogLevel::Normal); assert_eq!(config.log_level, LogLevel::Normal);
} }
"release" => { "release" => {
assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST);
assert_eq!(config.port, 8000);
assert_eq!(config.workers, 12); assert_eq!(config.workers, 12);
assert_eq!(config.keep_alive, 5); assert_eq!(config.keep_alive, 5);
assert_eq!(config.log_level, LogLevel::Critical); assert_eq!(config.log_level, LogLevel::Critical);

View File

@ -74,19 +74,8 @@ fn hello(lang: Option<Lang>, opt: Options<'_>) -> String {
#[launch] #[launch]
fn rocket() -> _ { fn rocket() -> _ {
use rocket::fairing::AdHoc;
rocket::build() rocket::build()
.mount("/", routes![hello]) .mount("/", routes![hello])
.mount("/hello", routes![world, mir]) .mount("/hello", routes![world, mir])
.mount("/wave", routes![wave]) .mount("/wave", routes![wave])
.attach(AdHoc::on_request("Compatibility Normalizer", |req, _| Box::pin(async move {
if !req.uri().is_normalized_nontrailing() {
let normal = req.uri().clone().into_normalized_nontrailing();
warn!("Incoming request URI was normalized for compatibility.");
info_!("{} -> {}", req.uri(), normal);
req.set_uri(normal);
}
})))
} }

View File

@ -1,33 +1,38 @@
//! Redirect all HTTP requests to HTTPs. //! Redirect all HTTP requests to HTTPs.
use std::sync::OnceLock; use std::net::SocketAddr;
use rocket::http::Status; use rocket::http::Status;
use rocket::log::LogLevel; use rocket::log::LogLevel;
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite, Config}; use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite};
use rocket::fairing::{Fairing, Info, Kind}; use rocket::fairing::{Fairing, Info, Kind};
use rocket::response::Redirect; use rocket::response::Redirect;
use yansi::Paint;
#[derive(Debug, Clone, Copy, Default)]
pub struct Redirector(u16);
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Redirector { pub struct Config {
pub listen_port: u16, server: rocket::Config,
pub tls_port: OnceLock<u16>, tls_addr: SocketAddr,
} }
impl Redirector { impl Redirector {
pub fn on(port: u16) -> Self { pub fn on(port: u16) -> Self {
Redirector { listen_port: port, tls_port: OnceLock::new() } Redirector(port)
} }
// Route function that gets called on every single request. // Route function that gets called on every single request.
fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> { fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
// FIXME: Check the host against a whitelist! // FIXME: Check the host against a whitelist!
let redirector = req.rocket().state::<Self>().expect("managed Self"); let config = req.rocket().state::<Config>().expect("managed Self");
if let Some(host) = req.host() { if let Some(host) = req.host() {
let domain = host.domain(); let domain = host.domain();
let https_uri = match redirector.tls_port.get() { let https_uri = match config.tls_addr.port() {
Some(443) | None => format!("https://{domain}{}", req.uri()), 443 => format!("https://{domain}{}", req.uri()),
Some(port) => format!("https://{domain}:{port}{}", req.uri()), port => format!("https://{domain}:{port}{}", req.uri()),
}; };
route::Outcome::from(req, Redirect::permanent(https_uri)).pin() route::Outcome::from(req, Redirect::permanent(https_uri)).pin()
@ -37,21 +42,12 @@ impl Redirector {
} }
// Launch an instance of Rocket than handles redirection on `self.port`. // Launch an instance of Rocket than handles redirection on `self.port`.
pub async fn try_launch(self, mut config: Config) -> Result<Rocket<Ignite>, Error> { pub async fn try_launch(self, config: Config) -> Result<Rocket<Ignite>, Error> {
use yansi::Paint;
use rocket::http::Method::*; use rocket::http::Method::*;
// Determine the port TLS is being served on.
let tls_port = self.tls_port.get_or_init(|| config.port);
// Adjust config for redirector: disable TLS, set port, disable logging.
config.tls = None;
config.port = self.listen_port;
config.log_level = LogLevel::Critical;
info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta()); info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta());
info_!("redirecting on insecure port {} to TLS port {}", info_!("redirecting insecure port {} to TLS port {}",
self.listen_port.yellow(), tls_port.green()); self.0.yellow(), config.tls_addr.port().green());
// Build a vector of routes to `redirect` on `<path..>` for each method. // Build a vector of routes to `redirect` on `<path..>` for each method.
let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch] let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch]
@ -59,10 +55,11 @@ impl Redirector {
.map(|m| Route::new(m, "/<path..>", Self::redirect)) .map(|m| Route::new(m, "/<path..>", Self::redirect))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
rocket::custom(config) let addr = SocketAddr::new(config.tls_addr.ip(), self.0);
.manage(self) rocket::custom(&config.server)
.manage(config)
.mount("/", redirects) .mount("/", redirects)
.launch() .launch_on(addr)
.await .await
} }
} }
@ -76,8 +73,24 @@ impl Fairing for Redirector {
} }
} }
async fn on_liftoff(&self, rkt: &Rocket<Orbit>) { async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
let (this, shutdown, config) = (self.clone(), rkt.shutdown(), rkt.config().clone()); let Some(tls_addr) = rocket.endpoint().tls().and_then(|tls| tls.tcp()) else {
info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta());
warn_!("Main instance is not being served over TLS/TCP.");
warn_!("Redirector refusing to start.");
return;
};
let config = Config {
tls_addr,
server: rocket::Config {
log_level: LogLevel::Critical,
..rocket.config().clone()
},
};
let this = *self;
let shutdown = rocket.shutdown();
let _ = rocket::tokio::spawn(async move { let _ = rocket::tokio::spawn(async move {
if let Err(e) = this.try_launch(config).await { if let Err(e) = this.try_launch(config).await {
error!("Failed to start HTTP -> HTTPS redirector."); error!("Failed to start HTTP -> HTTPS redirector.");

View File

@ -1,11 +1,21 @@
use std::fs::{self, File}; use std::fs::{self, File};
use rocket::http::{CookieJar, Cookie};
use rocket::local::blocking::Client; use rocket::local::blocking::Client;
use rocket::fs::relative; use rocket::fs::relative;
#[get("/cookie")]
fn cookie(jar: &CookieJar<'_>) {
jar.add(("k1", "v1"));
jar.add_private(("k2", "v2"));
jar.add(Cookie::build(("k1u", "v1u")).secure(false));
jar.add_private(Cookie::build(("k2u", "v2u")).secure(false));
}
#[test] #[test]
fn hello_mutual() { fn hello_mutual() {
let client = Client::tracked(super::rocket()).unwrap(); let client = Client::tracked_secure(super::rocket()).unwrap();
let cert_paths = fs::read_dir(relative!("private")).unwrap() let cert_paths = fs::read_dir(relative!("private")).unwrap()
.map(|entry| entry.unwrap().path().to_string_lossy().into_owned()) .map(|entry| entry.unwrap().path().to_string_lossy().into_owned())
.filter(|path| path.ends_with("_cert.pem") && !path.ends_with("ca_cert.pem")); .filter(|path| path.ends_with("_cert.pem") && !path.ends_with("ca_cert.pem"));
@ -23,35 +33,43 @@ fn hello_mutual() {
#[test] #[test]
fn secure_cookies() { fn secure_cookies() {
use rocket::http::{CookieJar, Cookie}; let rocket = super::rocket().mount("/", routes![cookie]);
let client = Client::tracked_secure(rocket).unwrap();
#[get("/cookie")]
fn cookie(jar: &CookieJar<'_>) {
jar.add(("k1", "v1"));
jar.add_private(("k2", "v2"));
jar.add(Cookie::build(("k1u", "v1u")).secure(false));
jar.add_private(Cookie::build(("k2u", "v2u")).secure(false));
}
let client = Client::tracked(super::rocket().mount("/", routes![cookie])).unwrap();
let response = client.get("/cookie").dispatch(); let response = client.get("/cookie").dispatch();
let c1 = response.cookies().get("k1").unwrap(); let c1 = response.cookies().get("k1").unwrap();
assert_eq!(c1.secure(), Some(true));
let c2 = response.cookies().get_private("k2").unwrap(); let c2 = response.cookies().get_private("k2").unwrap();
let c3 = response.cookies().get("k1u").unwrap();
let c4 = response.cookies().get_private("k2u").unwrap();
assert_eq!(c1.secure(), Some(true));
assert_eq!(c2.secure(), Some(true)); assert_eq!(c2.secure(), Some(true));
assert_ne!(c3.secure(), Some(true));
assert_ne!(c4.secure(), Some(true));
}
let c1 = response.cookies().get("k1u").unwrap(); #[test]
assert_ne!(c1.secure(), Some(true)); fn insecure_cookies() {
let rocket = super::rocket().mount("/", routes![cookie]);
let client = Client::tracked(rocket).unwrap();
let c2 = response.cookies().get_private("k2u").unwrap(); let response = client.get("/cookie").dispatch();
assert_ne!(c2.secure(), Some(true)); let c1 = response.cookies().get("k1").unwrap();
let c2 = response.cookies().get_private("k2").unwrap();
let c3 = response.cookies().get("k1u").unwrap();
let c4 = response.cookies().get_private("k2u").unwrap();
assert_eq!(c1.secure(), None);
assert_eq!(c2.secure(), None);
assert_eq!(c3.secure(), None);
assert_eq!(c4.secure(), None);
} }
#[test] #[test]
fn hello_world() { fn hello_world() {
use rocket::listener::DefaultListener;
use rocket::config::{Config, SecretKey};
let profiles = [ let profiles = [
"rsa_sha256", "rsa_sha256",
"ecdsa_nistp256_sha256_pkcs8", "ecdsa_nistp256_sha256_pkcs8",
@ -61,11 +79,20 @@ fn hello_world() {
"ed25519", "ed25519",
]; ];
// TODO: Testing doesn't actually read keys since we don't do TLS locally.
for profile in profiles { for profile in profiles {
let config = rocket::Config::figment().select(profile); let config = Config {
let client = Client::tracked(super::rocket().configure(config)).unwrap(); secret_key: SecretKey::generate().unwrap(),
..Config::debug_default()
};
let figment = Config::figment().merge(config).select(profile);
let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap();
let response = client.get("/").dispatch(); let response = client.get("/").dispatch();
assert_eq!(response.into_string().unwrap(), "Hello, world!"); assert_eq!(response.into_string().unwrap(), "Hello, world!");
let figment = client.rocket().figment();
let listener: DefaultListener = figment.extract().unwrap();
assert_eq!(figment.profile(), profile);
listener.tls.as_ref().unwrap().validate().expect("valid TLS config");
} }
} }

View File

@ -14,7 +14,7 @@
<div id="log"></div> <div id="log"></div>
</body> </body>
<script language="javascript" type="text/javascript"> <script language="javascript" type="text/javascript">
var wsUri = "ws://127.0.0.1:8000/echo"; var wsUri = "ws://127.0.0.1:8000/echo?raw";
var log; var log;
function init() { function init() {

View File

@ -20,7 +20,9 @@ fi
echo ":::: Generating the docs..." echo ":::: Generating the docs..."
pushd "${PROJECT_ROOT}" > /dev/null 2>&1 pushd "${PROJECT_ROOT}" > /dev/null 2>&1
# Set the crate version and fill in missing doc URLs with docs.rs links. # Set the crate version and fill in missing doc URLs with docs.rs links.
RUSTDOCFLAGS="-Zunstable-options --crate-version ${DOC_VERSION} --extern-html-root-url rocket=https://api.rocket.rs/rocket/" \ RUSTDOCFLAGS="-Z unstable-options \
--crate-version ${DOC_VERSION} \
--enable-index-page" \
cargo doc -Zrustdoc-map --no-deps --all-features \ cargo doc -Zrustdoc-map --no-deps --all-features \
-p rocket \ -p rocket \
-p rocket_db_pools \ -p rocket_db_pools \

View File

@ -126,10 +126,11 @@ function test_contrib() {
function test_core() { function test_core() {
FEATURES=( FEATURES=(
tokio-macros
http2
secrets secrets
tls tls
mtls mtls
http2
json json
msgpack msgpack
uuid uuid