Support QUIC and HTTP/3.

This commit adds support for HTTP/3 and QUIC under a disabled-by-default
feature `http3-preview`. The current implementation depends on modified
versions of h3 and s2n-quic-h3 which will need to be upstreamed and
published before a release is possible.

During the course of development various facets of Rocket's internal
connection handling and recent listener APIs were improved. The complete
list of changes included in this PR is:

  * A `shutdown` module was introduced.
  * `config::Shutdown` was renamed to `ShutdownConfig` and moved to
    `shutdown` while being re-exported from `config`.
  * `ListenerAddr` is now called `Endpoint`. Various methods which
    previously referred to "addresses" now refer to "endpoints".
  * `Rocket::endpoint()` was renamed to `Rocket::endpoints()` and now
    returns an iterator over the endpoints Rocket is listening on.
  * `Endpoint` acquired various query utility methods.
  * The `{set_}remote()` methods now take/produce `Endpoint`s.
  * `TlsBindable` only accepts single-phase internal interfaces.
  * Bind error messages include candidate endpoint info when possible.
  * The warning message when a secret key is not configured now includes
    information about its effect on private cookies.

Internal changes include:

  * Config module tests were moved to `config/tests.rs`.
  * The cancellable I/O implementation was significantly simplified.
  * The `TripWire` implementation was simplified.
  * Individual shutdown stages can now be awaited on via `Stages`.
  * The `Shield` implementation was simplified.

Resolves #2723.
This commit is contained in:
Sergio Benitez 2024-03-18 20:19:00 -07:00
parent 50c44e8fdc
commit 1619bbbddc
49 changed files with 1535 additions and 1043 deletions

View File

@ -82,7 +82,7 @@ fn client(routes: Vec<Route>) -> Client {
profile: Config::RELEASE_PROFILE,
log_level: rocket::config::LogLevel::Off,
cli_colors: config::CliColors::Never,
shutdown: config::Shutdown {
shutdown: config::ShutdownConfig {
ctrlc: false,
#[cfg(unix)]
signals: HashSet::new(),

View File

@ -12,6 +12,7 @@ mod sqlite_shutdown_test {
let options = map!["url" => ":memory:"];
let config = Figment::from(rocket::Config::debug_default())
.merge(("port", 0))
.merge(("databases", map!["test" => &options]));
rocket::custom(config).attach(Pool::fairing())

View File

@ -22,6 +22,7 @@ all-features = true
[features]
default = ["http2", "tokio-macros"]
http2 = ["hyper/http2", "hyper-util/http2"]
http3-preview = ["s2n-quic", "s2n-quic-h3", "tls"]
secrets = ["cookie/private", "cookie/key-expansion"]
json = ["serde_json"]
msgpack = ["rmp-serde"]
@ -76,8 +77,7 @@ futures = { version = "0.3.30", default-features = false, features = ["std"] }
state = "0.6"
[dependencies.hyper-util]
git = "https://github.com/SergioBenitez/hyper-util.git"
branch = "fix-readversion"
version = "0.1.3"
default-features = false
features = ["http1", "server", "tokio"]
@ -99,6 +99,16 @@ version = "0.6.0-dev"
path = "../http"
features = ["serde"]
[dependencies.s2n-quic]
version = "1.32"
default-features = false
features = ["provider-address-token-default", "provider-tls-rustls"]
optional = true
[dependencies.s2n-quic-h3]
git = "https://github.com/SergioBenitez/s2n-quic-h3.git"
optional = true
[target.'cfg(unix)'.dependencies]
libc = "0.2.149"

View File

@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use yansi::{Paint, Style, Color::Primary};
use crate::log::PaintExt;
use crate::config::{LogLevel, Shutdown, Ident, CliColors};
use crate::config::{LogLevel, ShutdownConfig, Ident, CliColors};
use crate::request::{self, Request, FromRequest};
use crate::http::uncased::Uncased;
use crate::data::Limits;
@ -120,8 +120,8 @@ pub struct Config {
#[cfg_attr(nightly, doc(cfg(feature = "secrets")))]
#[serde(serialize_with = "SecretKey::serialize_zero")]
pub secret_key: SecretKey,
/// Graceful shutdown configuration. **(default: [`Shutdown::default()`])**
pub shutdown: Shutdown,
/// Graceful shutdown configuration. **(default: [`ShutdownConfig::default()`])**
pub shutdown: ShutdownConfig,
/// Max level to log. **(default: _debug_ `normal` / _release_ `critical`)**
pub log_level: LogLevel,
/// Whether to use colors and emoji when logging. **(default:
@ -200,7 +200,7 @@ impl Config {
keep_alive: 5,
#[cfg(feature = "secrets")]
secret_key: SecretKey::zero(),
shutdown: Shutdown::default(),
shutdown: ShutdownConfig::default(),
log_level: LogLevel::Normal,
cli_colors: CliColors::Auto,
__non_exhaustive: (),
@ -408,9 +408,10 @@ impl Config {
#[cfg(feature = "secrets")] {
launch_meta_!("secret key: {}", self.secret_key.paint(VAL));
if !self.secret_key.is_provided() {
warn!("secrets enabled without a stable `secret_key`");
launch_meta_!("disable `secrets` feature or configure a `secret_key`");
launch_meta_!("this becomes an {} in non-debug profiles", "error".red());
warn!("secrets enabled without configuring a stable `secret_key`");
warn_!("private/signed cookies will become unreadable after restarting");
launch_meta_!("disable the `secrets` feature or configure a `secret_key`");
launch_meta_!("this becomes a {} in non-debug profiles", "hard error".red());
}
}
}

View File

@ -113,422 +113,35 @@
#[macro_use]
mod ident;
mod config;
mod shutdown;
mod cli_colors;
mod http_header;
#[cfg(test)]
mod tests;
pub use ident::Ident;
pub use config::Config;
pub use cli_colors::CliColors;
pub use crate::log::LogLevel;
pub use crate::shutdown::ShutdownConfig;
#[cfg(feature = "tls")]
pub use crate::tls::TlsConfig;
#[cfg(feature = "mtls")]
pub use crate::mtls::MtlsConfig;
#[cfg(feature = "secrets")]
mod secret_key;
#[doc(hidden)]
pub use config::{pretty_print_error, bail_with_config_error};
#[cfg(unix)]
pub use crate::shutdown::Sig;
pub use config::Config;
pub use crate::log::LogLevel;
pub use shutdown::Shutdown;
pub use ident::Ident;
pub use cli_colors::CliColors;
#[cfg(unix)]
pub use crate::listener::unix::UdsConfig;
#[cfg(feature = "secrets")]
pub use secret_key::SecretKey;
#[cfg(unix)]
pub use shutdown::Sig;
#[cfg(test)]
mod tests {
use figment::{Figment, Profile};
use pretty_assertions::assert_eq;
use crate::log::LogLevel;
use crate::data::{Limits, ToByteUnit};
use crate::config::{Config, CliColors};
#[test]
fn test_figment_is_default() {
figment::Jail::expect_with(|_| {
let mut default: Config = Config::figment().extract().unwrap();
default.profile = Config::default().profile;
assert_eq!(default, Config::default());
Ok(())
});
}
#[test]
fn test_default_round_trip() {
figment::Jail::expect_with(|_| {
let original = Config::figment();
let roundtrip = Figment::from(Config::from(&original));
for figment in &[original, roundtrip] {
let config = Config::from(figment);
assert_eq!(config, Config::default());
}
Ok(())
});
}
#[test]
fn test_profile_env() {
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PROFILE", "debug");
let figment = Config::figment();
assert_eq!(figment.profile(), "debug");
jail.set_env("ROCKET_PROFILE", "release");
let figment = Config::figment();
assert_eq!(figment.profile(), "release");
jail.set_env("ROCKET_PROFILE", "random");
let figment = Config::figment();
assert_eq!(figment.profile(), "random");
Ok(())
});
}
#[test]
fn test_toml_file() {
figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#"
[default]
ident = "Something Cool"
workers = 20
keep_alive = 10
log_level = "off"
cli_colors = 0
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
workers: 20,
ident: ident!("Something Cool"),
keep_alive: 10,
log_level: LogLevel::Off,
cli_colors: CliColors::Never,
..Config::default()
});
jail.create_file("Rocket.toml", r#"
[global]
ident = "Something Else Cool"
workers = 20
keep_alive = 10
log_level = "off"
cli_colors = 0
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
workers: 20,
ident: ident!("Something Else Cool"),
keep_alive: 10,
log_level: LogLevel::Off,
cli_colors: CliColors::Never,
..Config::default()
});
jail.set_env("ROCKET_CONFIG", "Other.toml");
jail.create_file("Other.toml", r#"
[default]
workers = 20
keep_alive = 10
log_level = "off"
cli_colors = 0
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
workers: 20,
keep_alive: 10,
log_level: LogLevel::Off,
cli_colors: CliColors::Never,
..Config::default()
});
Ok(())
});
}
#[test]
fn test_cli_colors() {
figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = "never"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = "auto"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = "always"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Always);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = true
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = false
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.create_file("Rocket.toml", r#"[default]"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = 1
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = 0
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.set_env("ROCKET_CLI_COLORS", 1);
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.set_env("ROCKET_CLI_COLORS", 0);
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.set_env("ROCKET_CLI_COLORS", true);
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.set_env("ROCKET_CLI_COLORS", false);
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.set_env("ROCKET_CLI_COLORS", "always");
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Always);
jail.set_env("ROCKET_CLI_COLORS", "NEveR");
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.set_env("ROCKET_CLI_COLORS", "auTO");
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
Ok(())
})
}
#[test]
fn test_profiles_merge() {
figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#"
[default.limits]
stream = "50kb"
[global]
limits = { forms = "2kb" }
[debug.limits]
file = "100kb"
"#)?;
jail.set_env("ROCKET_PROFILE", "unknown");
let config = Config::from(Config::figment());
assert_eq!(config, Config {
profile: Profile::const_new("unknown"),
limits: Limits::default()
.limit("stream", 50.kilobytes())
.limit("forms", 2.kilobytes()),
..Config::default()
});
jail.set_env("ROCKET_PROFILE", "debug");
let config = Config::from(Config::figment());
assert_eq!(config, Config {
profile: Profile::const_new("debug"),
limits: Limits::default()
.limit("stream", 50.kilobytes())
.limit("forms", 2.kilobytes())
.limit("file", 100.kilobytes()),
..Config::default()
});
Ok(())
});
}
#[test]
fn test_env_vars_merge() {
use crate::config::{Ident, Shutdown};
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_KEEP_ALIVE", 9999);
let config = Config::from(Config::figment());
assert_eq!(config, Config {
keep_alive: 9999,
..Config::default()
});
jail.set_env("ROCKET_SHUTDOWN", r#"{grace=7}"#);
let first_figment = Config::figment();
jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=10}"#);
let prev_figment = Config::figment().join(&first_figment);
let config = Config::from(&prev_figment);
assert_eq!(config, Config {
keep_alive: 9999,
shutdown: Shutdown { grace: 7, mercy: 10, ..Default::default() },
..Config::default()
});
jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=20}"#);
let config = Config::from(Config::figment().join(&prev_figment));
assert_eq!(config, Config {
keep_alive: 9999,
shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
..Config::default()
});
jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#);
let config = Config::from(Config::figment().join(&prev_figment));
assert_eq!(config, Config {
keep_alive: 9999,
shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
limits: Limits::default().limit("stream", 100.kibibytes()),
..Config::default()
});
jail.set_env("ROCKET_IDENT", false);
let config = Config::from(Config::figment().join(&prev_figment));
assert_eq!(config, Config {
keep_alive: 9999,
shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
limits: Limits::default().limit("stream", 100.kibibytes()),
ident: Ident::none(),
..Config::default()
});
Ok(())
});
}
#[test]
fn test_precedence() {
figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#"
[global.limits]
forms = "1mib"
stream = "50kb"
file = "100kb"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
limits: Limits::default()
.limit("forms", 1.mebibytes())
.limit("stream", 50.kilobytes())
.limit("file", 100.kilobytes()),
..Config::default()
});
jail.set_env("ROCKET_LIMITS", r#"{stream=3MiB,capture=2MiB}"#);
let config = Config::from(Config::figment());
assert_eq!(config, Config {
limits: Limits::default()
.limit("file", 100.kilobytes())
.limit("forms", 1.mebibytes())
.limit("stream", 3.mebibytes())
.limit("capture", 2.mebibytes()),
..Config::default()
});
jail.set_env("ROCKET_PROFILE", "foo");
let val: Result<String, _> = Config::figment().extract_inner("profile");
assert!(val.is_err());
Ok(())
});
}
#[test]
#[cfg(feature = "secrets")]
#[should_panic]
fn test_err_on_non_debug_and_no_secret_key() {
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PROFILE", "release");
let rocket = crate::custom(Config::figment());
let _result = crate::local::blocking::Client::untracked(rocket);
Ok(())
});
}
#[test]
#[cfg(feature = "secrets")]
#[should_panic]
fn test_err_on_non_debug2_and_no_secret_key() {
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PROFILE", "boop");
let rocket = crate::custom(Config::figment());
let _result = crate::local::blocking::Client::tracked(rocket);
Ok(())
});
}
#[test]
fn test_no_err_on_debug_and_no_secret_key() {
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PROFILE", "debug");
let figment = Config::figment();
assert!(crate::local::blocking::Client::untracked(crate::custom(&figment)).is_ok());
crate::async_main(async {
let rocket = crate::custom(&figment);
assert!(crate::local::asynchronous::Client::tracked(rocket).await.is_ok());
});
Ok(())
});
}
#[test]
fn test_no_err_on_release_and_custom_secret_key() {
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PROFILE", "release");
let key = "Bx4Gb+aSIfuoEyMHD4DvNs92+wmzfQK98qc6MiwyPY4=";
let figment = Config::figment().merge(("secret_key", key));
assert!(crate::local::blocking::Client::tracked(crate::custom(&figment)).is_ok());
crate::async_main(async {
let rocket = crate::custom(&figment);
assert!(crate::local::asynchronous::Client::untracked(rocket).await.is_ok());
});
Ok(())
});
}
}
#[doc(hidden)]
pub use config::{pretty_print_error, bail_with_config_error};

View File

@ -0,0 +1,394 @@
use figment::{Figment, Profile};
use pretty_assertions::assert_eq;
use crate::log::LogLevel;
use crate::data::{Limits, ToByteUnit};
use crate::config::{Config, CliColors};
#[test]
fn test_figment_is_default() {
figment::Jail::expect_with(|_| {
let mut default: Config = Config::figment().extract().unwrap();
default.profile = Config::default().profile;
assert_eq!(default, Config::default());
Ok(())
});
}
#[test]
fn test_default_round_trip() {
figment::Jail::expect_with(|_| {
let original = Config::figment();
let roundtrip = Figment::from(Config::from(&original));
for figment in &[original, roundtrip] {
let config = Config::from(figment);
assert_eq!(config, Config::default());
}
Ok(())
});
}
#[test]
fn test_profile_env() {
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PROFILE", "debug");
let figment = Config::figment();
assert_eq!(figment.profile(), "debug");
jail.set_env("ROCKET_PROFILE", "release");
let figment = Config::figment();
assert_eq!(figment.profile(), "release");
jail.set_env("ROCKET_PROFILE", "random");
let figment = Config::figment();
assert_eq!(figment.profile(), "random");
Ok(())
});
}
#[test]
fn test_toml_file() {
figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#"
[default]
ident = "Something Cool"
workers = 20
keep_alive = 10
log_level = "off"
cli_colors = 0
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
workers: 20,
ident: ident!("Something Cool"),
keep_alive: 10,
log_level: LogLevel::Off,
cli_colors: CliColors::Never,
..Config::default()
});
jail.create_file("Rocket.toml", r#"
[global]
ident = "Something Else Cool"
workers = 20
keep_alive = 10
log_level = "off"
cli_colors = 0
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
workers: 20,
ident: ident!("Something Else Cool"),
keep_alive: 10,
log_level: LogLevel::Off,
cli_colors: CliColors::Never,
..Config::default()
});
jail.set_env("ROCKET_CONFIG", "Other.toml");
jail.create_file("Other.toml", r#"
[default]
workers = 20
keep_alive = 10
log_level = "off"
cli_colors = 0
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
workers: 20,
keep_alive: 10,
log_level: LogLevel::Off,
cli_colors: CliColors::Never,
..Config::default()
});
Ok(())
});
}
#[test]
fn test_cli_colors() {
figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = "never"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = "auto"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = "always"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Always);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = true
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = false
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.create_file("Rocket.toml", r#"[default]"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = 1
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.create_file("Rocket.toml", r#"
[default]
cli_colors = 0
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.set_env("ROCKET_CLI_COLORS", 1);
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.set_env("ROCKET_CLI_COLORS", 0);
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.set_env("ROCKET_CLI_COLORS", true);
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
jail.set_env("ROCKET_CLI_COLORS", false);
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.set_env("ROCKET_CLI_COLORS", "always");
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Always);
jail.set_env("ROCKET_CLI_COLORS", "NEveR");
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Never);
jail.set_env("ROCKET_CLI_COLORS", "auTO");
let config = Config::from(Config::figment());
assert_eq!(config.cli_colors, CliColors::Auto);
Ok(())
})
}
#[test]
fn test_profiles_merge() {
figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#"
[default.limits]
stream = "50kb"
[global]
limits = { forms = "2kb" }
[debug.limits]
file = "100kb"
"#)?;
jail.set_env("ROCKET_PROFILE", "unknown");
let config = Config::from(Config::figment());
assert_eq!(config, Config {
profile: Profile::const_new("unknown"),
limits: Limits::default()
.limit("stream", 50.kilobytes())
.limit("forms", 2.kilobytes()),
..Config::default()
});
jail.set_env("ROCKET_PROFILE", "debug");
let config = Config::from(Config::figment());
assert_eq!(config, Config {
profile: Profile::const_new("debug"),
limits: Limits::default()
.limit("stream", 50.kilobytes())
.limit("forms", 2.kilobytes())
.limit("file", 100.kilobytes()),
..Config::default()
});
Ok(())
});
}
#[test]
fn test_env_vars_merge() {
use crate::config::{Ident, ShutdownConfig};
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_KEEP_ALIVE", 9999);
let config = Config::from(Config::figment());
assert_eq!(config, Config {
keep_alive: 9999,
..Config::default()
});
jail.set_env("ROCKET_SHUTDOWN", r#"{grace=7}"#);
let first_figment = Config::figment();
jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=10}"#);
let prev_figment = Config::figment().join(&first_figment);
let config = Config::from(&prev_figment);
assert_eq!(config, Config {
keep_alive: 9999,
shutdown: ShutdownConfig { grace: 7, mercy: 10, ..Default::default() },
..Config::default()
});
jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=20}"#);
let config = Config::from(Config::figment().join(&prev_figment));
assert_eq!(config, Config {
keep_alive: 9999,
shutdown: ShutdownConfig { grace: 7, mercy: 20, ..Default::default() },
..Config::default()
});
jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#);
let config = Config::from(Config::figment().join(&prev_figment));
assert_eq!(config, Config {
keep_alive: 9999,
shutdown: ShutdownConfig { grace: 7, mercy: 20, ..Default::default() },
limits: Limits::default().limit("stream", 100.kibibytes()),
..Config::default()
});
jail.set_env("ROCKET_IDENT", false);
let config = Config::from(Config::figment().join(&prev_figment));
assert_eq!(config, Config {
keep_alive: 9999,
shutdown: ShutdownConfig { grace: 7, mercy: 20, ..Default::default() },
limits: Limits::default().limit("stream", 100.kibibytes()),
ident: Ident::none(),
..Config::default()
});
Ok(())
});
}
#[test]
fn test_precedence() {
figment::Jail::expect_with(|jail| {
jail.create_file("Rocket.toml", r#"
[global.limits]
forms = "1mib"
stream = "50kb"
file = "100kb"
"#)?;
let config = Config::from(Config::figment());
assert_eq!(config, Config {
limits: Limits::default()
.limit("forms", 1.mebibytes())
.limit("stream", 50.kilobytes())
.limit("file", 100.kilobytes()),
..Config::default()
});
jail.set_env("ROCKET_LIMITS", r#"{stream=3MiB,capture=2MiB}"#);
let config = Config::from(Config::figment());
assert_eq!(config, Config {
limits: Limits::default()
.limit("file", 100.kilobytes())
.limit("forms", 1.mebibytes())
.limit("stream", 3.mebibytes())
.limit("capture", 2.mebibytes()),
..Config::default()
});
jail.set_env("ROCKET_PROFILE", "foo");
let val: Result<String, _> = Config::figment().extract_inner("profile");
assert!(val.is_err());
Ok(())
});
}
#[test]
#[cfg(feature = "secrets")]
#[should_panic]
fn test_err_on_non_debug_and_no_secret_key() {
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PROFILE", "release");
let rocket = crate::custom(Config::figment());
let _result = crate::local::blocking::Client::untracked(rocket);
Ok(())
});
}
#[test]
#[cfg(feature = "secrets")]
#[should_panic]
fn test_err_on_non_debug2_and_no_secret_key() {
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PROFILE", "boop");
let rocket = crate::custom(Config::figment());
let _result = crate::local::blocking::Client::tracked(rocket);
Ok(())
});
}
#[test]
fn test_no_err_on_debug_and_no_secret_key() {
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PROFILE", "debug");
let figment = Config::figment();
assert!(crate::local::blocking::Client::untracked(crate::custom(&figment)).is_ok());
crate::async_main(async {
let rocket = crate::custom(&figment);
assert!(crate::local::asynchronous::Client::tracked(rocket).await.is_ok());
});
Ok(())
});
}
#[test]
fn test_no_err_on_release_and_custom_secret_key() {
figment::Jail::expect_with(|jail| {
jail.set_env("ROCKET_PROFILE", "release");
let key = "Bx4Gb+aSIfuoEyMHD4DvNs92+wmzfQK98qc6MiwyPY4=";
let figment = Config::figment().merge(("secret_key", key));
assert!(crate::local::blocking::Client::tracked(crate::custom(&figment)).is_ok());
crate::async_main(async {
let rocket = crate::custom(&figment);
assert!(crate::local::asynchronous::Client::untracked(rocket).await.is_ok());
});
Ok(())
});
}

View File

@ -68,7 +68,9 @@ pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
/// Raw underlying data stream.
pub enum RawStream<'r> {
Empty,
Body(&'r mut HyperBody),
Body(HyperBody),
#[cfg(feature = "http3-preview")]
H3Body(crate::listener::Cancellable<crate::listener::quic::QuicRx>),
Multipart(multer::Field<'r>),
}
@ -343,7 +345,9 @@ impl Stream for RawStream<'_> {
.poll_frame(cx)
.map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new()))
.map_err(io::Error::other)
}
},
#[cfg(feature = "http3-preview")]
RawStream::H3Body(stream) => Pin::new(stream).poll_next(cx),
RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other),
RawStream::Empty => Poll::Ready(None),
}
@ -356,6 +360,8 @@ impl Stream for RawStream<'_> {
let (lower, upper) = (hint.lower(), hint.upper());
(lower as usize, upper.map(|x| x as usize))
},
#[cfg(feature = "http3-preview")]
RawStream::H3Body(_) => (0, Some(0)),
RawStream::Multipart(mp) => mp.size_hint(),
RawStream::Empty => (0, Some(0)),
}
@ -367,17 +373,26 @@ impl std::fmt::Display for RawStream<'_> {
match self {
RawStream::Empty => f.write_str("empty stream"),
RawStream::Body(_) => f.write_str("request body"),
#[cfg(feature = "http3-preview")]
RawStream::H3Body(_) => f.write_str("http3 quic stream"),
RawStream::Multipart(_) => f.write_str("multipart form field"),
}
}
}
impl<'r> From<&'r mut HyperBody> for RawStream<'r> {
fn from(value: &'r mut HyperBody) -> Self {
impl<'r> From<HyperBody> for RawStream<'r> {
fn from(value: HyperBody) -> Self {
Self::Body(value)
}
}
#[cfg(feature = "http3-preview")]
impl<'r> From<crate::listener::Cancellable<crate::listener::quic::QuicRx>> for RawStream<'r> {
fn from(value: crate::listener::Cancellable<crate::listener::quic::QuicRx>) -> Self {
Self::H3Body(value)
}
}
impl<'r> From<multer::Field<'r>> for RawStream<'r> {
fn from(value: multer::Field<'r>) -> Self {
Self::Multipart(value)

View File

@ -18,3 +18,5 @@ pub use self::capped::{N, Capped};
pub use self::io_stream::{IoHandler, IoStream};
pub use ubyte::{ByteUnit, ToByteUnit};
pub use self::transform::{Transform, TransformBuf};
pub(crate) use self::data_stream::RawStream;

View File

@ -6,19 +6,18 @@ use std::task::{Poll, Context};
use futures::future::BoxFuture;
use http::request::Parts;
use hyper::body::Incoming;
use tokio::io::{AsyncRead, ReadBuf};
use crate::data::{Data, IoHandler};
use crate::data::{Data, IoHandler, RawStream};
use crate::{Request, Response, Rocket, Orbit};
// TODO: Magic with trait async fn to get rid of the box pin.
// TODO: Write safety proofs.
macro_rules! static_assert_covariance {
($T:tt) => (
($($T:tt)*) => (
const _: () = {
fn _assert_covariance<'x: 'y, 'y>(x: &'y $T<'x>) -> &'y $T<'y> { x }
fn _assert_covariance<'x: 'y, 'y>(x: &'y $($T)*<'x>) -> &'y $($T)*<'y> { x }
};
)
}
@ -40,7 +39,6 @@ pub struct ErasedResponse {
// XXX: SAFETY: This (dependent) field must come first due to drop order!
response: Response<'static>,
_request: Arc<ErasedRequest>,
_incoming: Box<Incoming>,
}
impl Drop for ErasedResponse {
@ -79,10 +77,9 @@ impl ErasedRequest {
ErasedRequest { _rocket: rocket, _parts: parts, request, }
}
pub async fn into_response<T: Send + Sync + 'static>(
pub async fn into_response<T, D>(
self,
incoming: Incoming,
data_builder: impl for<'r> FnOnce(&'r mut Incoming) -> Data<'r>,
raw_stream: D,
preprocess: impl for<'r, 'x> FnOnce(
&'r Rocket<Orbit>,
&'r mut Request<'x>,
@ -94,14 +91,11 @@ impl ErasedRequest {
&'r Request<'r>,
Data<'r>
) -> BoxFuture<'r, Response<'r>>,
) -> ErasedResponse {
let mut incoming = Box::new(incoming);
let mut data: Data<'_> = {
let incoming: &mut Incoming = &mut *incoming;
let incoming: &'static mut Incoming = unsafe { transmute(incoming) };
data_builder(incoming)
};
) -> ErasedResponse
where T: Send + Sync + 'static,
D: for<'r> Into<RawStream<'r>>
{
let mut data: Data<'_> = Data::from(raw_stream);
let mut parent = Arc::new(self);
let token: T = {
let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap();
@ -122,7 +116,6 @@ impl ErasedRequest {
ErasedResponse {
_request: parent,
_incoming: incoming,
response: response,
}
}

View File

@ -7,7 +7,8 @@ use std::error::Error as StdError;
use yansi::Paint;
use figment::Profile;
use crate::{Rocket, Orbit};
use crate::listener::Endpoint;
use crate::{Ignite, Orbit, Rocket};
/// An error that occurs during launch.
///
@ -74,8 +75,8 @@ pub struct Error {
#[derive(Debug)]
#[non_exhaustive]
pub enum ErrorKind {
/// Binding to the network interface failed.
Bind(Box<dyn StdError + Send>),
/// Binding to the network interface at `.0` failed with error `.1`.
Bind(Option<Endpoint>, Box<dyn StdError + Send>),
/// An I/O error occurred during launch.
Io(io::Error),
/// A valid [`Config`](crate::Config) could not be extracted from the
@ -89,6 +90,11 @@ pub enum ErrorKind {
SentinelAborts(Vec<crate::sentinel::Sentry>),
/// The configuration profile is not debug but no secret key is configured.
InsecureSecretKey(Profile),
/// Liftoff failed. Contains the Rocket instance that failed to shutdown.
Liftoff(
Result<Rocket<Ignite>, Arc<Rocket<Orbit>>>,
Box<dyn StdError + Send + 'static>
),
/// Shutdown failed. Contains the Rocket instance that failed to shutdown.
Shutdown(Arc<Rocket<Orbit>>),
}
@ -171,8 +177,12 @@ impl Error {
pub fn pretty_print(&self) -> &'static str {
self.mark_handled();
match self.kind() {
ErrorKind::Bind(ref e) => {
error!("Binding to the network interface failed.");
ErrorKind::Bind(ref a, ref e) => {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
}
info_!("{}", e);
"aborting due to bind error"
}
@ -225,6 +235,11 @@ impl Error {
"aborting due to sentinel-triggered abort(s)"
}
ErrorKind::Liftoff(_, error) => {
error!("Rocket liftoff failed due to panicking liftoff fairing(s).");
error_!("{error}");
"aborting due to failed liftoff"
}
ErrorKind::Shutdown(_) => {
error!("Rocket failed to shutdown gracefully.");
"aborting due to failed shutdown"
@ -239,13 +254,14 @@ impl fmt::Display for ErrorKind {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ErrorKind::Bind(e) => write!(f, "binding failed: {e}"),
ErrorKind::Bind(_, e) => write!(f, "binding failed: {e}"),
ErrorKind::Io(e) => write!(f, "I/O error: {e}"),
ErrorKind::Collisions(_) => "collisions detected".fmt(f),
ErrorKind::FailedFairings(_) => "launch fairing(s) failed".fmt(f),
ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f),
ErrorKind::Config(_) => "failed to extract configuration".fmt(f),
ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f),
ErrorKind::Liftoff(_, _) => "liftoff failed".fmt(f),
ErrorKind::Shutdown(_) => "shutdown failed".fmt(f),
}
}
@ -293,40 +309,45 @@ impl fmt::Display for Empty {
impl StdError for Empty { }
/// Log an error that occurs during request processing
pub(crate) fn log_server_error(error: &Box<dyn StdError + Send + Sync>) {
#[track_caller]
pub(crate) fn log_server_error(error: &(dyn StdError + 'static)) {
struct ServerError<'a>(&'a (dyn StdError + 'static));
impl fmt::Display for ServerError<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let error = &self.0;
if let Some(e) = error.downcast_ref::<hyper::Error>() {
write!(f, "request processing failed: {e}")?;
write!(f, "request failed: {e}")?;
} else if let Some(e) = error.downcast_ref::<io::Error>() {
write!(f, "connection I/O error: ")?;
write!(f, "connection error: ")?;
match e.kind() {
io::ErrorKind::NotConnected => write!(f, "remote disconnected")?,
io::ErrorKind::UnexpectedEof => write!(f, "remote sent early eof")?,
io::ErrorKind::ConnectionReset
| io::ErrorKind::ConnectionAborted
| io::ErrorKind::BrokenPipe => write!(f, "terminated by remote")?,
| io::ErrorKind::ConnectionAborted => write!(f, "terminated by remote")?,
_ => write!(f, "{e}")?,
}
} else {
write!(f, "http server error: {error}")?;
}
if let Some(e) = error.source() {
write!(f, " ({})", ServerError(e))?;
}
Ok(())
}
}
let mut error: &(dyn StdError + 'static) = &*error;
if error.downcast_ref::<hyper::Error>().is_some() {
warn!("{}", ServerError(&**error))
warn!("{}", ServerError(error));
while let Some(source) = error.source() {
error = source;
warn_!("{}", ServerError(error));
}
} else {
error!("{}", ServerError(&**error))
error!("{}", ServerError(error));
while let Some(source) = error.source() {
error = source;
error_!("{}", ServerError(error));
}
}
}

View File

@ -187,7 +187,7 @@ impl AdHoc {
/// Constructs an `AdHoc` shutdown fairing named `name`. The function `f`
/// will be called by Rocket when [shutdown is triggered].
///
/// [shutdown is triggered]: crate::config::Shutdown#triggers
/// [shutdown is triggered]: crate::config::ShutdownConfig#triggers
///
/// # Example
///

View File

@ -191,8 +191,8 @@ pub type Result<T = Rocket<Build>, E = Rocket<Build>> = std::result::Result<T, E
/// ***Note: Shutdown fairings are only run during testing if the `Client`
/// is terminated using [`Client::terminate()`].***
///
/// [shutdown is triggered]: crate::config::Shutdown#triggers
/// [grace and mercy periods]: crate::config::Shutdown#summary
/// [shutdown is triggered]: crate::config::ShutdownConfig#triggers
/// [grace and mercy periods]: crate::config::ShutdownConfig#summary
/// [`Client::terminate()`]: crate::local::blocking::Client::terminate()
///
/// # Singletons
@ -525,7 +525,7 @@ pub trait Fairing: Send + Sync + Any + 'static {
/// is in the `kind` field of the `Info` structure for this fairing. The
/// `Rocket` parameter corresponds to the running application.
///
/// [shutdown is triggered]: crate::config::Shutdown#triggers
/// [shutdown is triggered]: crate::config::ShutdownConfig#triggers
///
/// ## Default Implementation
///

View File

@ -176,7 +176,7 @@ impl<'a> CookieJar<'a> {
ops: Mutex::new(Vec::new()),
state: CookieState {
// This is updated dynamically when headers are received.
secure: rocket.endpoint().is_tls(),
secure: rocket.endpoints().all(|e| e.is_tls()),
config: rocket.config(),
}
}

View File

@ -61,15 +61,17 @@
//! To avoid compiling unused dependencies, Rocket gates certain features. With
//! the exception of `http2`, all are disabled by default:
//!
//! | Feature | Description |
//! |-----------|---------------------------------------------------------|
//! | `secrets` | Support for authenticated, encrypted [private cookies]. |
//! | `tls` | Support for [TLS] encrypted connections. |
//! | `mtls` | Support for verified clients via [mutual TLS]. |
//! | `http2` | Support for HTTP/2 (enabled by default). |
//! | `json` | Support for [JSON (de)serialization]. |
//! | `msgpack` | Support for [MessagePack (de)serialization]. |
//! | `uuid` | Support for [UUID value parsing and (de)serialization]. |
//! | Feature | Description |
//! |-----------------|---------------------------------------------------------|
//! | `secrets` | Support for authenticated, encrypted [private cookies]. |
//! | `tls` | Support for [TLS] encrypted connections. |
//! | `mtls` | Support for verified clients via [mutual TLS]. |
//! | `http2` | Support for HTTP/2 (enabled by default). |
//! | `json` | Support for [JSON (de)serialization]. |
//! | `msgpack` | Support for [MessagePack (de)serialization]. |
//! | `uuid` | Support for [UUID value parsing and (de)serialization]. |
//! | `tokio-macros` | Enables the `macros` feature in the exported `tokio` |
//! | `http3-preview` | Experimental preview support for [HTTP/3]. |
//!
//! Disabled features can be selectively enabled in `Cargo.toml`:
//!
@ -91,6 +93,7 @@
//! [private cookies]: https://rocket.rs/master/guide/requests/#private-cookies
//! [TLS]: https://rocket.rs/master/guide/configuration/#tls
//! [mutual TLS]: crate::mtls
//! [HTTP/3]: crate::listener::quic
//!
//! ## Configuration
//!
@ -143,6 +146,7 @@ pub mod shield;
pub mod fs;
pub mod http;
pub mod listener;
pub mod shutdown;
#[cfg(feature = "tls")]
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
pub mod tls;
@ -151,7 +155,6 @@ pub mod tls;
pub mod mtls;
mod util;
mod shutdown;
mod server;
mod lifecycle;
mod state;

View File

@ -1,6 +1,7 @@
use std::io;
use futures::TryFutureExt;
use crate::listener::Listener;
use crate::listener::{Listener, Endpoint};
pub trait Bindable: Sized {
type Listener: Listener + 'static;
@ -8,6 +9,9 @@ pub trait Bindable: Sized {
type Error: std::error::Error + Send + 'static;
async fn bind(self) -> Result<Self::Listener, Self::Error>;
/// The endpoint that `self` binds on.
fn candidate_endpoint(&self) -> io::Result<Endpoint>;
}
impl<L: Listener + 'static> Bindable for L {
@ -18,6 +22,10 @@ impl<L: Listener + 'static> Bindable for L {
async fn bind(self) -> Result<Self::Listener, Self::Error> {
Ok(self)
}
fn candidate_endpoint(&self) -> io::Result<Endpoint> {
L::endpoint(self)
}
}
impl<A: Bindable, B: Bindable> Bindable for either::Either<A, B> {
@ -37,4 +45,8 @@ impl<A: Bindable, B: Bindable> Bindable for either::Either<A, B> {
.await,
}
}
fn candidate_endpoint(&self) -> io::Result<Endpoint> {
either::for_both!(self, a => a.candidate_endpoint())
}
}

View File

@ -52,7 +52,7 @@ impl<L: Listener + Sync> Listener for Bounced<L> {
self.listener.connect(accept).await
}
fn socket_addr(&self) -> io::Result<Endpoint> {
self.listener.socket_addr()
fn endpoint(&self) -> io::Result<Endpoint> {
self.listener.endpoint()
}
}

View File

@ -1,178 +1,79 @@
use std::io;
use std::time::Duration;
use std::task::{Poll, Context};
use std::pin::Pin;
use tokio::time::{sleep, Sleep};
use futures::Stream;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use futures::{StreamExt, future::{select, Either, Fuse, Future, FutureExt}};
use futures::future::FutureExt;
use pin_project_lite::pin_project;
use crate::{config, Shutdown};
use crate::listener::{Listener, Connection, Certificates, Bounced, Endpoint};
// Rocket wraps all connections in a `CancellableIo` struct, an internal
// structure that gracefully closes I/O when it receives a signal. That signal
// is the `shutdown` future. When the future resolves, `CancellableIo` begins to
// terminate in grace, mercy, and finally force close phases. Since all
// connections are wrapped in `CancellableIo`, this eventually ends all I/O.
//
// At that point, unless a user spawned an infinite, stand-alone task that isn't
// monitoring `Shutdown`, all tasks should resolve. This means that all
// instances of the shared `Arc<Rocket>` are dropped and we can return the owned
// instance of `Rocket`.
//
// Unfortunately, the Hyper `server` future resolves as soon as it has finished
// processing requests without respect for ongoing responses. That is, `server`
// resolves even when there are running tasks that are generating a response.
// So, `server` resolving implies little to nothing about the state of
// connections. As a result, we depend on the timing of grace + mercy + some
// buffer to determine when all connections should be closed, thus all tasks
// should be complete, thus all references to `Arc<Rocket>` should be dropped
// and we can get a unique reference.
pin_project! {
pub struct CancellableListener<F, L> {
pub trigger: F,
#[pin]
pub listener: L,
pub grace: Duration,
pub mercy: Duration,
}
}
use crate::shutdown::Stages;
pin_project! {
/// I/O that can be cancelled when a future `F` resolves.
#[must_use = "futures do nothing unless polled"]
pub struct CancellableIo<F, I> {
pub struct Cancellable<I> {
#[pin]
io: Option<I>,
#[pin]
trigger: Fuse<F>,
stages: Stages,
state: State,
grace: Duration,
mercy: Duration,
}
}
#[derive(Debug)]
enum State {
/// I/O has not been cancelled. Proceed as normal.
/// I/O has not been cancelled. Proceed as normal until `Shutdown`.
Active,
/// I/O has been cancelled. See if we can finish before the timer expires.
Grace(Pin<Box<Sleep>>),
/// Grace period elapsed. Shutdown the connection, waiting for the timer
/// until we force close.
Mercy(Pin<Box<Sleep>>),
/// I/O has been cancelled. Try to finish before `Shutdown`.
Grace,
/// Grace has elapsed. Shutdown connections. After `Shutdown`, force close.
Mercy,
}
pub trait CancellableExt: Sized {
fn cancellable(
self,
trigger: Shutdown,
config: &config::Shutdown
) -> CancellableListener<Shutdown, Self> {
if let Some(mut stream) = config.signal_stream() {
let trigger = trigger.clone();
tokio::spawn(async move {
while let Some(sig) = stream.next().await {
if trigger.0.tripped() {
warn!("Received {}. Shutdown already in progress.", sig);
} else {
warn!("Received {}. Requesting shutdown.", sig);
}
trigger.0.trip();
}
});
};
CancellableListener {
trigger,
listener: self,
grace: config.grace(),
mercy: config.mercy(),
fn cancellable(self, stages: Stages) -> Cancellable<Self> {
Cancellable {
io: Some(self),
state: State::Active,
stages,
}
}
}
impl<L: Listener> CancellableExt for L { }
impl<T> CancellableExt for T { }
fn time_out() -> io::Error {
io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out")
io::Error::new(io::ErrorKind::TimedOut, "shutdown grace period elapsed")
}
fn gone() -> io::Error {
io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated")
io::Error::new(io::ErrorKind::BrokenPipe, "I/O driver terminated")
}
impl<L, F> CancellableListener<F, Bounced<L>>
where L: Listener + Sync,
F: Future + Unpin + Clone + Send + Sync + 'static
{
pub async fn accept_next(&self) -> Option<<Self as Listener>::Accept> {
let next = std::pin::pin!(self.listener.accept_next());
match select(next, self.trigger.clone()).await {
Either::Left((next, _)) => Some(next),
Either::Right(_) => None,
}
}
}
impl<L, F> CancellableListener<F, L>
where L: Listener + Sync,
F: Future + Clone + Send + Sync + 'static
{
fn io<C>(&self, conn: C) -> CancellableIo<F, C> {
CancellableIo {
io: Some(conn),
trigger: self.trigger.clone().fuse(),
state: State::Active,
grace: self.grace,
mercy: self.mercy,
}
}
}
impl<L, F> Listener for CancellableListener<F, L>
where L: Listener + Sync,
F: Future + Clone + Send + Sync + Unpin + 'static
{
type Accept = L::Accept;
type Connection = CancellableIo<F, L::Connection>;
async fn accept(&self) -> io::Result<Self::Accept> {
let accept = std::pin::pin!(self.listener.accept());
match select(accept, self.trigger.clone()).await {
Either::Left((result, _)) => result,
Either::Right(_) => Err(gone()),
}
}
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
let conn = std::pin::pin!(self.listener.connect(accept));
match select(conn, self.trigger.clone()).await {
Either::Left((conn, _)) => Ok(self.io(conn?)),
Either::Right(_) => Err(gone()),
}
}
fn socket_addr(&self) -> io::Result<Endpoint> {
self.listener.socket_addr()
}
}
impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
fn inner(&self) -> Option<&I> {
impl<I: AsyncCancel> Cancellable<I> {
pub fn inner(&self) -> Option<&I> {
self.io.as_ref()
}
}
pub trait AsyncCancel {
fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
impl<T: AsyncWrite> AsyncCancel for T {
fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<T as AsyncWrite>::poll_shutdown(self, cx)
}
}
impl<I: AsyncCancel> Cancellable<I> {
/// Run `do_io` while connection processing should continue.
fn poll_trigger_then<T>(
pub fn poll_with<T>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
) -> Poll<io::Result<T>> {
let mut me = self.as_mut().project();
let me = self.as_mut().project();
let io = match me.io.as_pin_mut() {
Some(io) => io,
None => return Poll::Ready(Err(gone())),
@ -181,29 +82,29 @@ impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
loop {
match me.state {
State::Active => {
if me.trigger.as_mut().poll(cx).is_ready() {
*me.state = State::Grace(Box::pin(sleep(*me.grace)));
if me.stages.start.poll_unpin(cx).is_ready() {
*me.state = State::Grace;
} else {
return do_io(io, cx);
}
}
State::Grace(timer) => {
if timer.as_mut().poll(cx).is_ready() {
*me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
State::Grace => {
if me.stages.grace.poll_unpin(cx).is_ready() {
*me.state = State::Mercy;
} else {
return do_io(io, cx);
}
}
State::Mercy(timer) => {
if timer.as_mut().poll(cx).is_ready() {
State::Mercy => {
if me.stages.mercy.poll_unpin(cx).is_ready() {
self.project().io.set(None);
return Poll::Ready(Err(time_out()));
} else {
let result = futures::ready!(io.poll_shutdown(cx));
let result = futures::ready!(io.poll_cancel(cx));
self.project().io.set(None);
return match result {
Ok(()) => Poll::Ready(Err(gone())),
Err(e) => Poll::Ready(Err(e)),
Ok(()) => Poll::Ready(Err(gone()))
};
}
},
@ -212,45 +113,45 @@ impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
}
}
impl<F: Future, I: AsyncRead + AsyncWrite> AsyncRead for CancellableIo<F, I> {
impl<I: AsyncRead + AsyncCancel> AsyncRead for Cancellable<I> {
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf))
self.poll_with(cx, |io, cx| io.poll_read(cx, buf))
}
}
impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
impl<I: AsyncWrite> AsyncWrite for Cancellable<I> {
fn poll_write(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf))
self.poll_with(cx, |io, cx| io.poll_write(cx, buf))
}
fn poll_flush(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx))
self.poll_with(cx, |io, cx| io.poll_flush(cx))
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx))
self.poll_with(cx, |io, cx| io.poll_shutdown(cx))
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs))
self.poll_with(cx, |io, cx| io.poll_write_vectored(cx, bufs))
}
fn is_write_vectored(&self) -> bool {
@ -258,16 +159,16 @@ impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
}
}
impl<F: Future, C: Connection> Connection for CancellableIo<F, C>
where F: Unpin + Send + 'static
{
fn peer_address(&self) -> io::Result<Endpoint> {
self.inner()
.ok_or_else(|| gone())
.and_then(|io| io.peer_address())
}
impl<T, I: Stream<Item = io::Result<T>> + AsyncCancel> Stream for Cancellable<I> {
type Item = I::Item;
fn peer_certificates(&self) -> Option<Certificates<'_>> {
self.inner().and_then(|io| io.peer_certificates())
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use futures::ready;
match ready!(self.poll_with(cx, |io, cx| io.poll_next(cx).map(Ok))) {
Ok(Some(v)) => Poll::Ready(Some(v)),
Ok(None) => Poll::Ready(None),
Err(e) => Poll::Ready(Some(Err(e))),
}
}
}

View File

@ -2,7 +2,6 @@ use std::io;
use std::borrow::Cow;
use tokio_util::either::Either;
use tokio::io::{AsyncRead, AsyncWrite};
use super::Endpoint;
@ -10,8 +9,8 @@ use super::Endpoint;
#[derive(Clone)]
pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>);
pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
fn peer_address(&self) -> io::Result<Endpoint>;
pub trait Connection: Send + Unpin {
fn endpoint(&self) -> io::Result<Endpoint>;
/// DER-encoded X.509 certificate chain presented by the client, if any.
///
@ -21,21 +20,21 @@ pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
///
/// Defaults to an empty vector to indicate that no certificates were
/// presented.
fn peer_certificates(&self) -> Option<Certificates<'_>> { None }
fn certificates(&self) -> Option<Certificates<'_>> { None }
}
impl<A: Connection, B: Connection> Connection for Either<A, B> {
fn peer_address(&self) -> io::Result<Endpoint> {
fn endpoint(&self) -> io::Result<Endpoint> {
match self {
Either::Left(c) => c.peer_address(),
Either::Right(c) => c.peer_address(),
Either::Left(c) => c.endpoint(),
Either::Right(c) => c.endpoint(),
}
}
fn peer_certificates(&self) -> Option<Certificates<'_>> {
fn certificates(&self) -> Option<Certificates<'_>> {
match self {
Either::Left(c) => c.peer_certificates(),
Either::Right(c) => c.peer_certificates(),
Either::Left(c) => c.certificates(),
Either::Right(c) => c.certificates(),
}
}
}

View File

@ -32,15 +32,15 @@ impl DefaultListener {
Ok(BaseBindable::Right(uds))
},
#[cfg(not(unix))]
Endpoint::Unix(_) => {
e@Endpoint::Unix(_) => {
let msg = "Unix domain sockets unavailable on non-unix platforms.";
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(boxed)))
Err(Error::new(ErrorKind::Bind(Some(e.clone()), boxed)))
},
other => {
let msg = format!("unsupported default listener address: {other}");
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(boxed)))
Err(Error::new(ErrorKind::Bind(Some(other.clone()), boxed)))
}
}
}

View File

@ -1,7 +1,7 @@
use std::fmt;
use std::path::{Path, PathBuf};
use std::any::Any;
use std::net::{SocketAddr as TcpAddr, Ipv4Addr, AddrParseError};
use std::net::{self, AddrParseError, IpAddr, Ipv4Addr};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
@ -9,23 +9,23 @@ use serde::de;
use crate::http::uncased::AsUncased;
#[cfg(feature = "tls")] type TlsInfo = Option<Box<crate::tls::TlsConfig>>;
#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>;
pub trait EndpointAddr: fmt::Display + fmt::Debug + Sync + Send + Any { }
impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> EndpointAddr for T {}
#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>;
#[cfg(feature = "tls")] type TlsInfo = Option<crate::tls::TlsConfig>;
/// # Conversions
///
/// * [`&str`] - parse with [`FromStr`]
/// * [`tokio::net::unix::SocketAddr`] - must be path: [`ListenerAddr::Unix`]
/// * [`std::net::SocketAddr`] - infallibly as [ListenerAddr::Tcp]
/// * [`PathBuf`] - infallibly as [`ListenerAddr::Unix`]
// TODO: Rename to something better. `Endpoint`?
#[derive(Debug)]
/// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`]
/// * [`PathBuf`] - infallibly as [`Endpoint::Unix`]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Endpoint {
Tcp(TcpAddr),
Tcp(net::SocketAddr),
Quic(net::SocketAddr),
Unix(PathBuf),
Tls(Arc<Endpoint>, TlsInfo),
Custom(Arc<dyn EndpointAddr>),
@ -36,9 +36,45 @@ impl Endpoint {
Endpoint::Custom(Arc::new(value))
}
pub fn tcp(&self) -> Option<TcpAddr> {
pub fn tcp(&self) -> Option<net::SocketAddr> {
match self {
Endpoint::Tcp(addr) => Some(*addr),
Endpoint::Tls(addr, _) => addr.tcp(),
_ => None,
}
}
pub fn quic(&self) -> Option<net::SocketAddr> {
match self {
Endpoint::Quic(addr) => Some(*addr),
Endpoint::Tls(addr, _) => addr.tcp(),
_ => None,
}
}
pub fn socket_addr(&self) -> Option<net::SocketAddr> {
match self {
Endpoint::Quic(addr) => Some(*addr),
Endpoint::Tcp(addr) => Some(*addr),
Endpoint::Tls(inner, _) => inner.socket_addr(),
_ => None,
}
}
pub fn ip(&self) -> Option<IpAddr> {
match self {
Endpoint::Quic(addr) => Some(addr.ip()),
Endpoint::Tcp(addr) => Some(addr.ip()),
Endpoint::Tls(inner, _) => inner.ip(),
_ => None,
}
}
pub fn port(&self) -> Option<u16> {
match self {
Endpoint::Quic(addr) => Some(addr.port()),
Endpoint::Tcp(addr) => Some(addr.port()),
Endpoint::Tls(inner, _) => inner.port(),
_ => None,
}
}
@ -46,6 +82,7 @@ impl Endpoint {
pub fn unix(&self) -> Option<&Path> {
match self {
Endpoint::Unix(addr) => Some(addr),
Endpoint::Tls(addr, _) => addr.unix(),
_ => None,
}
}
@ -76,6 +113,7 @@ impl Endpoint {
pub fn downcast<T: 'static>(&self) -> Option<&T> {
match self {
Endpoint::Tcp(addr) => (&*addr as &dyn Any).downcast_ref(),
Endpoint::Quic(addr) => (&*addr as &dyn Any).downcast_ref(),
Endpoint::Unix(addr) => (&*addr as &dyn Any).downcast_ref(),
Endpoint::Custom(addr) => (&*addr as &dyn Any).downcast_ref(),
Endpoint::Tls(inner, ..) => inner.downcast(),
@ -86,6 +124,10 @@ impl Endpoint {
self.tcp().is_some()
}
pub fn is_quic(&self) -> bool {
self.quic().is_some()
}
pub fn is_unix(&self) -> bool {
self.unix().is_some()
}
@ -95,12 +137,12 @@ impl Endpoint {
}
#[cfg(feature = "tls")]
pub fn with_tls(self, config: crate::tls::TlsConfig) -> Endpoint {
pub fn with_tls(self, tls: &crate::tls::TlsConfig) -> Endpoint {
if self.is_tls() {
return self;
}
Self::Tls(Arc::new(self), Some(config))
Self::Tls(Arc::new(self), Some(Box::new(tls.clone())))
}
pub fn assume_tls(self) -> Endpoint {
@ -117,39 +159,27 @@ impl fmt::Display for Endpoint {
use Endpoint::*;
match self {
Tcp(addr) => write!(f, "http://{addr}"),
Tcp(addr) | Quic(addr) => write!(f, "http://{addr}"),
Unix(addr) => write!(f, "unix:{}", addr.display()),
Custom(inner) => inner.fmt(f),
Tls(inner, c) => match (&**inner, c.as_ref()) {
Tls(inner, _c) => {
match (inner.tcp(), inner.quic()) {
(Some(addr), _) => write!(f, "https://{addr} (TCP")?,
(_, Some(addr)) => write!(f, "https://{addr} (QUIC")?,
(None, None) => write!(f, "{inner} (TLS")?,
}
#[cfg(feature = "mtls")]
(Tcp(i), Some(c)) if c.mutual().is_some() => write!(f, "https://{i} (TLS + MTLS)"),
(Tcp(i), _) => write!(f, "https://{i} (TLS)"),
#[cfg(feature = "mtls")]
(i, Some(c)) if c.mutual().is_some() => write!(f, "{i} (TLS + MTLS)"),
(inner, _) => write!(f, "{inner} (TLS)"),
},
if _c.as_ref().and_then(|c| c.mutual()).is_some() {
write!(f, " + mTLS")?;
}
write!(f, ")")
}
}
}
}
impl From<std::net::SocketAddr> for Endpoint {
fn from(value: std::net::SocketAddr) -> Self {
Self::Tcp(value)
}
}
impl From<std::net::SocketAddrV4> for Endpoint {
fn from(value: std::net::SocketAddrV4) -> Self {
Self::Tcp(value.into())
}
}
impl From<std::net::SocketAddrV6> for Endpoint {
fn from(value: std::net::SocketAddrV6) -> Self {
Self::Tcp(value.into())
}
}
impl From<PathBuf> for Endpoint {
fn from(value: PathBuf) -> Self {
Self::Unix(value)
@ -177,36 +207,38 @@ impl TryFrom<&str> for Endpoint {
impl Default for Endpoint {
fn default() -> Self {
Endpoint::Tcp(TcpAddr::new(Ipv4Addr::LOCALHOST.into(), 8000))
Endpoint::Tcp(net::SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000))
}
}
/// Parses an address into a `ListenerAddr`.
/// Parses an address into a `Endpoint`.
///
/// The syntax is:
///
/// ```text
/// listener_addr = 'tcp' ':' tcp_addr | 'unix' ':' unix_addr | tcp_addr
/// tcp_addr := IP_ADDR | SOCKET_ADDR
/// unix_addr := PATH
/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket
/// socket := IP_ADDR | SOCKET_ADDR
/// path := PATH
///
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
/// ```
///
/// If `IP_ADDR` is specified, the port defaults to `8000`.
/// If `IP_ADDR` is specified in socket, port defaults to `8000`.
impl FromStr for Endpoint {
type Err = AddrParseError;
fn from_str(string: &str) -> Result<Self, Self::Err> {
fn parse_tcp(string: &str, def_port: u16) -> Result<TcpAddr, AddrParseError> {
string.parse().or_else(|_| string.parse().map(|ip| TcpAddr::new(ip, def_port)))
fn parse_tcp(str: &str, def_port: u16) -> Result<net::SocketAddr, AddrParseError> {
str.parse().or_else(|_| str.parse().map(|ip| net::SocketAddr::new(ip, def_port)))
}
if let Some((proto, string)) = string.split_once(':') {
if proto.trim().as_uncased() == "tcp" {
return parse_tcp(string.trim(), 8000).map(Self::Tcp);
} else if proto.trim().as_uncased() == "quic" {
return parse_tcp(string.trim(), 8000).map(Self::Quic);
} else if proto.trim().as_uncased() == "unix" {
return Ok(Self::Unix(PathBuf::from(string.trim())));
}
@ -242,6 +274,7 @@ impl PartialEq for Endpoint {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Tcp(l0), Self::Tcp(r0)) => l0 == r0,
(Self::Quic(l0), Self::Quic(r0)) => l0 == r0,
(Self::Unix(l0), Self::Unix(r0)) => l0 == r0,
(Self::Tls(l0, _), Self::Tls(r0, _)) => l0 == r0,
(Self::Custom(l0), Self::Custom(r0)) => l0.to_string() == r0.to_string(),
@ -250,24 +283,6 @@ impl PartialEq for Endpoint {
}
}
impl PartialEq<std::net::SocketAddr> for Endpoint {
fn eq(&self, other: &std::net::SocketAddr) -> bool {
self.tcp() == Some(*other)
}
}
impl PartialEq<std::net::SocketAddrV4> for Endpoint {
fn eq(&self, other: &std::net::SocketAddrV4) -> bool {
self.tcp() == Some((*other).into())
}
}
impl PartialEq<std::net::SocketAddrV6> for Endpoint {
fn eq(&self, other: &std::net::SocketAddrV6) -> bool {
self.tcp() == Some((*other).into())
}
}
impl PartialEq<PathBuf> for Endpoint {
fn eq(&self, other: &PathBuf) -> bool {
self.unix() == Some(other.as_path())

View File

@ -10,12 +10,13 @@ pub trait Listener: Send + Sync {
type Connection: Connection;
#[crate::async_bound(Send)]
async fn accept(&self) -> io::Result<Self::Accept>;
#[crate::async_bound(Send)]
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection>;
fn socket_addr(&self) -> io::Result<Endpoint>;
fn endpoint(&self) -> io::Result<Endpoint>;
}
impl<L: Listener> Listener for &L {
@ -31,8 +32,8 @@ impl<L: Listener> Listener for &L {
<L as Listener>::connect(self, accept).await
}
fn socket_addr(&self) -> io::Result<Endpoint> {
<L as Listener>::socket_addr(self)
fn endpoint(&self) -> io::Result<Endpoint> {
<L as Listener>::endpoint(self)
}
}
@ -56,10 +57,10 @@ impl<A: Listener, B: Listener> Listener for Either<A, B> {
}
}
fn socket_addr(&self) -> io::Result<Endpoint> {
fn endpoint(&self) -> io::Result<Endpoint> {
match self {
Either::Left(l) => l.socket_addr(),
Either::Right(l) => l.socket_addr(),
Either::Left(l) => l.endpoint(),
Either::Right(l) => l.endpoint(),
}
}
}

View File

@ -13,6 +13,8 @@ pub mod unix;
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
pub mod tls;
pub mod tcp;
#[cfg(feature = "http3-preview")]
pub mod quic;
pub use endpoint::*;
pub use listener::*;

View File

@ -0,0 +1,234 @@
//! Experimental support for Quic and HTTP/3.
//!
//! To enable Rocket's experimental support for HTTP/3 and Quic, enable the
//! `http3-preview` feature and provide a valid TLS configuration:
//!
//! ```toml
//! // Add the following to your Cargo.toml:
//! [dependencies]
//! rocket = { version = "0.6.0-dev", features = ["http3-preview"] }
//!
//! // In your Rocket.toml or other equivalent config source:
//! [default.tls]
//! certs = "private/rsa_sha256_cert.pem"
//! key = "private/rsa_sha256_key.pem"
//! ```
//!
//! The launch message confirms that Rocket is serving traffic over Quic in
//! addition to TCP:
//!
//! ```sh
//! > 🚀 Rocket has launched on https://127.0.0.1:8000 (QUIC + mTLS)
//! ```
//!
//! mTLS is not yet supported via this implementation.
use std::io;
use std::fmt;
use std::net::SocketAddr;
use std::pin::pin;
use s2n_quic as quic;
use s2n_quic_h3 as quic_h3;
use quic_h3::h3 as h3;
use bytes::Bytes;
use futures::Stream;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;
use crate::tls::TlsConfig;
use crate::listener::{Listener, Connection, Endpoint};
type H3Conn = h3::server::Connection<quic_h3::Connection, bytes::Bytes>;
pub struct QuicListener {
endpoint: SocketAddr,
listener: Mutex<quic::Server>,
tls: TlsConfig,
}
pub struct H3Stream(H3Conn);
pub struct H3Connection {
pub handle: quic::connection::Handle,
pub parts: http::request::Parts,
pub tx: QuicTx,
pub rx: QuicRx,
}
pub struct QuicRx(h3::server::RequestStream<quic_h3::RecvStream, Bytes>);
pub struct QuicTx(h3::server::RequestStream<quic_h3::SendStream<Bytes>, Bytes>);
impl QuicListener {
pub async fn bind(address: SocketAddr, tls: TlsConfig) -> Result<Self, io::Error> {
use quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES, Server as H3TlsServer};
// FIXME: Remove this as soon as `s2n_quic` is on rustls >= 0.22.
let cert_chain = crate::tls::util::load_cert_chain(&mut tls.certs_reader().unwrap())
.unwrap()
.into_iter()
.map(|v| v.to_vec())
.map(rustls::Certificate)
.collect::<Vec<_>>();
let key = crate::tls::util::load_key(&mut tls.key_reader().unwrap())
.unwrap()
.secret_der()
.to_vec();
let mut h3tls = rustls::server::ServerConfig::builder()
.with_cipher_suites(DEFAULT_CIPHERSUITES)
.with_safe_default_kx_groups()
.with_safe_default_protocol_versions()
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?
.with_client_cert_verifier(rustls::server::NoClientAuth::boxed())
.with_single_cert(cert_chain, rustls::PrivateKey(key))
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?;
h3tls.alpn_protocols = vec![b"h3".to_vec()];
h3tls.ignore_client_order = tls.prefer_server_cipher_order;
h3tls.session_storage = rustls::server::ServerSessionMemoryCache::new(1024);
h3tls.ticketer = rustls::Ticketer::new()
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS ticketer: {}", e)))?;
let listener = quic::Server::builder()
.with_tls(H3TlsServer::new(h3tls))
.unwrap_or_else(|e| match e { })
.with_io(address)?
.start()
.map_err(io::Error::other)?;
Ok(QuicListener {
tls,
endpoint: listener.local_addr()?,
listener: Mutex::new(listener),
})
}
}
impl Listener for QuicListener {
type Accept = quic::Connection;
type Connection = H3Stream;
async fn accept(&self) -> io::Result<Self::Accept> {
self.listener
.lock().await
.accept().await
.ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "closed"))
}
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
let quic_conn = quic_h3::Connection::new(accept);
let conn = H3Conn::new(quic_conn).await.map_err(io::Error::other)?;
Ok(H3Stream(conn))
}
fn endpoint(&self) -> io::Result<Endpoint> {
Ok(Endpoint::Quic(self.endpoint).with_tls(&self.tls))
}
}
impl H3Stream {
pub async fn accept(&mut self) -> io::Result<Option<H3Connection>> {
let handle = self.0.inner.conn.handle().clone();
let ((parts, _), (tx, rx)) = match self.0.accept().await {
Ok(Some((req, stream))) => (req.into_parts(), stream.split()),
Ok(None) => return Ok(None),
Err(e) => {
if matches!(e.try_get_code().map(|c| c.value()), Some(0 | 0x100)) {
return Ok(None)
}
return Err(io::Error::other(e));
}
};
Ok(Some(H3Connection { handle, parts, tx: QuicTx(tx), rx: QuicRx(rx) }))
}
}
impl QuicTx {
pub async fn send_response<S>(&mut self, response: http::Response<S>) -> io::Result<()>
where S: Stream<Item = io::Result<Bytes>>
{
let (parts, body) = response.into_parts();
let response = http::Response::from_parts(parts, ());
self.0.send_response(response).await.map_err(io::Error::other)?;
let mut body = pin!(body);
while let Some(bytes) = body.next().await {
let bytes = bytes.map_err(io::Error::other)?;
self.0.send_data(bytes).await.map_err(io::Error::other)?;
}
self.0.finish().await.map_err(io::Error::other)
}
pub fn cancel(&mut self) {
self.0.stop_stream(h3::error::Code::H3_NO_ERROR);
}
}
// FIXME: Expose certificates when possible.
impl Connection for H3Stream {
fn endpoint(&self) -> io::Result<Endpoint> {
let addr = self.0.inner.conn.handle().remote_addr()?;
Ok(Endpoint::Quic(addr).assume_tls())
}
}
// FIXME: Expose certificates when possible.
impl Connection for H3Connection {
fn endpoint(&self) -> io::Result<Endpoint> {
let addr = self.handle.remote_addr()?;
Ok(Endpoint::Quic(addr).assume_tls())
}
}
mod async_traits {
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use super::{Bytes, QuicRx};
use crate::listener::AsyncCancel;
use futures::Stream;
use s2n_quic_h3::h3;
impl Stream for QuicRx {
type Item = io::Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use bytes::Buf;
match ready!(self.0.poll_recv_data(cx)) {
Ok(Some(mut buf)) => Poll::Ready(Some(Ok(buf.copy_to_bytes(buf.remaining())))),
Ok(None) => Poll::Ready(None),
Err(e) => Poll::Ready(Some(Err(io::Error::other(e)))),
}
}
}
impl AsyncCancel for QuicRx {
fn poll_cancel(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.stop_sending(h3::error::Code::H3_NO_ERROR);
Poll::Ready(Ok(()))
}
}
}
impl fmt::Debug for H3Stream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("H3Stream").finish()
}
}
impl fmt::Debug for H3Connection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("H3Connection").finish()
}
}

View File

@ -13,6 +13,10 @@ impl Bindable for std::net::SocketAddr {
async fn bind(self) -> Result<Self::Listener, Self::Error> {
TcpListener::bind(self).await
}
fn candidate_endpoint(&self) -> io::Result<Endpoint> {
Ok(Endpoint::Tcp(*self))
}
}
impl Listener for TcpListener {
@ -31,13 +35,13 @@ impl Listener for TcpListener {
Ok(conn)
}
fn socket_addr(&self) -> io::Result<Endpoint> {
fn endpoint(&self) -> io::Result<Endpoint> {
self.local_addr().map(Endpoint::Tcp)
}
}
impl Connection for TcpStream {
fn peer_address(&self) -> io::Result<Endpoint> {
fn endpoint(&self) -> io::Result<Endpoint> {
self.peer_addr().map(Endpoint::Tcp)
}
}

View File

@ -3,6 +3,7 @@ use std::sync::Arc;
use serde::Deserialize;
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;
use crate::tls::{TlsConfig, Error};
@ -27,7 +28,7 @@ pub struct TlsBindable<I> {
}
impl TlsConfig {
pub(crate) fn acceptor(&self) -> Result<tokio_rustls::TlsAcceptor, Error> {
pub(crate) fn server_config(&self) -> Result<ServerConfig, Error> {
let provider = rustls::crypto::CryptoProvider {
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
..rustls::crypto::ring::default_provider()
@ -64,52 +65,60 @@ impl TlsConfig {
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
}
Ok(TlsAcceptor::from(Arc::new(tls_config)))
Ok(tls_config)
}
}
impl<I: Bindable> Bindable for TlsBindable<I> {
impl<I: Bindable> Bindable for TlsBindable<I>
where I::Listener: Listener<Accept = <I::Listener as Listener>::Connection>,
<I::Listener as Listener>::Connection: AsyncRead + AsyncWrite
{
type Listener = TlsListener<I::Listener>;
type Error = Error;
async fn bind(self) -> Result<Self::Listener, Self::Error> {
Ok(TlsListener {
acceptor: self.tls.acceptor()?,
acceptor: TlsAcceptor::from(Arc::new(self.tls.server_config()?)),
listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?,
config: self.tls,
})
}
fn candidate_endpoint(&self) -> io::Result<Endpoint> {
let inner = self.inner.candidate_endpoint()?;
Ok(inner.with_tls(&self.tls))
}
}
impl<L: Listener + Sync> Listener for TlsListener<L>
where L::Connection: Unpin
impl<L> Listener for TlsListener<L>
where L: Listener<Accept = <L as Listener>::Connection>,
L::Connection: AsyncRead + AsyncWrite
{
type Accept = L::Accept;
type Accept = L::Connection;
type Connection = TlsStream<L::Connection>;
async fn accept(&self) -> io::Result<Self::Accept> {
self.listener.accept().await
Ok(self.listener.accept().await?)
}
async fn connect(&self, accept: L::Accept) -> io::Result<Self::Connection> {
let conn = self.listener.connect(accept).await?;
async fn connect(&self, conn: L::Connection) -> io::Result<Self::Connection> {
self.acceptor.accept(conn).await
}
fn socket_addr(&self) -> io::Result<Endpoint> {
Ok(self.listener.socket_addr()?.with_tls(self.config.clone()))
fn endpoint(&self) -> io::Result<Endpoint> {
Ok(self.listener.endpoint()?.with_tls(&self.config))
}
}
impl<C: Connection + Unpin> Connection for TlsStream<C> {
fn peer_address(&self) -> io::Result<Endpoint> {
Ok(self.get_ref().0.peer_address()?.assume_tls())
impl<C: Connection> Connection for TlsStream<C> {
fn endpoint(&self) -> io::Result<Endpoint> {
Ok(self.get_ref().0.endpoint()?.assume_tls())
}
#[cfg(feature = "mtls")]
fn peer_certificates(&self) -> Option<Certificates<'_>> {
fn certificates(&self) -> Option<Certificates<'_>> {
let cert_chain = self.get_ref().1.peer_certificates()?;
Some(Certificates::from(cert_chain))
}

View File

@ -68,6 +68,10 @@ impl Bindable for UdsConfig {
Ok(UdsListener { lock, listener, path: self.path, })
}
fn candidate_endpoint(&self) -> io::Result<Endpoint> {
Ok(Endpoint::Unix(self.path.clone()))
}
}
impl Listener for UdsListener {
@ -83,13 +87,13 @@ impl Listener for UdsListener {
Ok(accept)
}
fn socket_addr(&self) -> io::Result<Endpoint> {
fn endpoint(&self) -> io::Result<Endpoint> {
self.listener.local_addr()?.try_into()
}
}
impl Connection for UnixStream {
fn peer_address(&self) -> io::Result<Endpoint> {
fn endpoint(&self) -> io::Result<Endpoint> {
self.local_addr()?.try_into()
}
}

View File

@ -59,12 +59,12 @@ impl Client {
tracked: bool,
secure: bool,
) -> Result<Client, Error> {
let mut listener = Endpoint::new("local client");
let mut endpoint = Endpoint::new("local client");
if secure {
listener = listener.assume_tls();
endpoint = endpoint.assume_tls();
}
let rocket = rocket.local_launch(listener).await?;
let rocket = rocket.local_launch(endpoint).await?;
let cookies = RwLock::new(cookie::CookieJar::new());
Ok(Client { rocket, cookies, tracked })
}

View File

@ -99,26 +99,27 @@ macro_rules! pub_request_impl {
/// Set the remote address of this request to `address`.
///
/// `address` may be any type that [can be converted into a `ListenerAddr`].
/// `address` may be any type that [can be converted into a `Endpoint`].
/// If `address` fails to convert, the remote is left unchanged.
///
/// [can be converted into a `ListenerAddr`]: crate::listener::ListenerAddr#conversions
/// [can be converted into a `Endpoint`]: crate::listener::Endpoint#conversions
///
/// # Examples
///
/// Set the remote address to "8.8.8.8:80":
///
/// ```rust
/// use std::net::{SocketAddrV4, Ipv4Addr};
/// use std::net::Ipv4Addr;
///
#[doc = $import]
///
/// # Client::_test(|_, request, _| {
/// let request: LocalRequest = request;
/// let req = request.remote("8.8.8.8:80");
/// let req = request.remote("tcp:8.8.8.8:80");
///
/// let addr = SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8).into(), 80);
/// assert_eq!(req.inner().remote().unwrap(), &addr);
/// let remote = req.inner().remote().unwrap().tcp().unwrap();
/// assert_eq!(remote.ip(), Ipv4Addr::new(8, 8, 8, 8));
/// assert_eq!(remote.port(), 80);
/// # });
/// ```
#[inline]

View File

@ -2,7 +2,8 @@ use state::TypeMap;
use figment::Figment;
use crate::listener::Endpoint;
use crate::{Catcher, Config, Rocket, Route, Shutdown};
use crate::shutdown::Stages;
use crate::{Catcher, Config, Rocket, Route};
use crate::router::Router;
use crate::fairing::Fairings;
@ -99,7 +100,7 @@ phases! {
pub(crate) figment: Figment,
pub(crate) config: Config,
pub(crate) state: TypeMap![Send + Sync],
pub(crate) shutdown: Shutdown,
pub(crate) shutdown: Stages,
}
/// The final launch [`Phase`]. See [Rocket#orbit](`Rocket#orbit`) for
@ -113,7 +114,7 @@ phases! {
pub(crate) figment: Figment,
pub(crate) config: Config,
pub(crate) state: TypeMap![Send + Sync],
pub(crate) shutdown: Shutdown,
pub(crate) endpoint: Endpoint,
pub(crate) shutdown: Stages,
pub(crate) endpoints: Vec<Endpoint>,
}
}

View File

@ -501,7 +501,7 @@ impl<'r> FromRequest<'r> for std::net::SocketAddr {
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
request.remote()
.and_then(|r| r.tcp())
.and_then(|r| r.socket_addr())
.or_forward(Status::InternalServerError)
}
}

View File

@ -39,7 +39,7 @@ pub struct Request<'r> {
/// Information derived from an incoming connection, if any.
#[derive(Clone, Default)]
pub(crate) struct ConnectionMeta {
pub peer_address: Option<Arc<Endpoint>>,
pub peer_endpoint: Option<Endpoint>,
#[cfg_attr(not(feature = "mtls"), allow(dead_code))]
pub peer_certs: Option<Arc<Certificates<'static>>>,
}
@ -47,8 +47,8 @@ pub(crate) struct ConnectionMeta {
impl<C: Connection> From<&C> for ConnectionMeta {
fn from(conn: &C) -> Self {
ConnectionMeta {
peer_address: conn.peer_address().ok().map(Arc::new),
peer_certs: conn.peer_certificates().map(|c| c.into_owned()).map(Arc::new),
peer_endpoint: conn.endpoint().ok(),
peer_certs: conn.certificates().map(|c| c.into_owned()).map(Arc::new),
}
}
}
@ -316,20 +316,21 @@ impl<'r> Request<'r> {
/// # Example
///
/// ```rust
/// use std::net::{SocketAddrV4, Ipv4Addr};
/// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// use rocket::listener::Endpoint;
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let mut req = c.get("/");
/// # let request = req.inner_mut();
///
/// assert_eq!(request.remote(), None);
///
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000);
/// request.set_remote(localhost);
/// assert_eq!(request.remote().unwrap(), &localhost);
/// let localhost = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8111);
/// request.set_remote(Endpoint::Tcp(localhost));
/// assert_eq!(request.remote().unwrap().tcp().unwrap(), localhost);
/// ```
#[inline(always)]
pub fn remote(&self) -> Option<&Endpoint> {
self.connection.peer_address.as_deref()
self.connection.peer_endpoint.as_ref()
}
/// Sets the remote address of `self` to `address`.
@ -339,20 +340,21 @@ impl<'r> Request<'r> {
/// Set the remote address to be 127.0.0.1:8111:
///
/// ```rust
/// use std::net::{SocketAddrV4, Ipv4Addr};
/// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// use rocket::listener::Endpoint;
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let mut req = c.get("/");
/// # let request = req.inner_mut();
///
/// assert_eq!(request.remote(), None);
///
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8111);
/// request.set_remote(localhost);
/// assert_eq!(request.remote().unwrap(), &localhost);
/// let localhost = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8111);
/// request.set_remote(Endpoint::Tcp(localhost));
/// assert_eq!(request.remote().unwrap().tcp().unwrap(), localhost);
/// ```
#[inline(always)]
pub fn set_remote<A: Into<Endpoint>>(&mut self, address: A) {
self.connection.peer_address = Some(Arc::new(address.into()));
pub fn set_remote(&mut self, endpoint: Endpoint) {
self.connection.peer_endpoint = Some(endpoint.into());
}
/// Returns the IP address of the configured
@ -491,14 +493,15 @@ impl<'r> Request<'r> {
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let mut req = c.get("/");
/// # let request = req.inner_mut();
/// # use std::net::{SocketAddrV4, Ipv4Addr};
/// # use std::net::{SocketAddr, IpAddr, Ipv4Addr};
/// # use rocket::listener::Endpoint;
///
/// // starting without an "X-Real-IP" header or remote address
/// assert!(request.client_ip().is_none());
///
/// // add a remote address; this is done by Rocket automatically
/// let localhost_9190 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9190);
/// request.set_remote(localhost_9190);
/// let localhost_9190 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9190);
/// request.set_remote(Endpoint::Tcp(localhost_9190));
/// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::LOCALHOST);
///
/// // now with an X-Real-IP header, the default value for `ip_header`.
@ -507,7 +510,7 @@ impl<'r> Request<'r> {
/// ```
#[inline]
pub fn client_ip(&self) -> Option<IpAddr> {
self.real_ip().or_else(|| Some(self.remote()?.tcp()?.ip()))
self.real_ip().or_else(|| self.remote()?.ip())
}
/// Returns a wrapped borrow to the cookies in `self`.

View File

@ -1,14 +1,17 @@
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Duration;
use yansi::Paint;
use either::Either;
use figment::{Figment, Provider};
use tokio::io::{AsyncRead, AsyncWrite};
use crate::{Catcher, Config, Route, Shutdown, sentinel, shield::Shield};
use crate::listener::{Endpoint, Bindable, DefaultListener};
use crate::shutdown::{Stages, Shutdown};
use crate::{sentinel, shield::Shield, Catcher, Config, Route};
use crate::listener::{Bindable, DefaultListener, Endpoint, Listener};
use crate::router::Router;
use crate::util::TripWire;
use crate::fairing::{Fairing, Fairings};
use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting};
use crate::phase::{Stateful, StateRef, State};
@ -575,11 +578,11 @@ impl Rocket<Build> {
// Ignite the rocket.
let rocket: Rocket<Ignite> = Rocket(Igniting {
router, config,
shutdown: Shutdown(TripWire::new()),
shutdown: Stages::new(),
figment: self.0.figment,
fairings: self.0.fairings,
state: self.0.state,
router, config,
});
// Query the sentinels, abort if requested.
@ -630,7 +633,7 @@ impl Rocket<Ignite> {
/// A completed graceful shutdown resolves the future returned by
/// [`Rocket::launch()`]. If [`Shutdown::notify()`] is called _before_ an
/// instance is launched, it will be immediately shutdown after liftoff. See
/// [`Shutdown`] and [`config::Shutdown`](crate::config::Shutdown) for
/// [`Shutdown`] and [`ShutdownConfig`](crate::config::ShutdownConfig) for
/// details on graceful shutdown.
///
/// # Example
@ -657,12 +660,12 @@ impl Rocket<Ignite> {
/// }
/// ```
pub fn shutdown(&self) -> Shutdown {
self.shutdown.clone()
self.shutdown.start.clone()
}
pub(crate) fn into_orbit(self, address: Endpoint) -> Rocket<Orbit> {
pub(crate) fn into_orbit(self, endpoints: Vec<Endpoint>) -> Rocket<Orbit> {
Rocket(Orbiting {
endpoint: address,
endpoints,
router: self.0.router,
fairings: self.0.fairings,
figment: self.0.figment,
@ -672,8 +675,8 @@ impl Rocket<Ignite> {
})
}
async fn _local_launch(self, addr: Endpoint) -> Rocket<Orbit> {
let rocket = self.into_orbit(addr);
async fn _local_launch(self, endpoint: Endpoint) -> Rocket<Orbit> {
let rocket = self.into_orbit(vec![endpoint]);
Rocket::liftoff(&rocket).await;
rocket
}
@ -687,13 +690,72 @@ impl Rocket<Ignite> {
})
}
async fn _launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error> {
let listener = bindable.bind().await.map_err(|e| ErrorKind::Bind(Box::new(e)))?;
self.serve(listener).await
async fn _launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error>
where <B::Listener as Listener>::Connection: AsyncRead + AsyncWrite
{
let rocket = self.bind_and_serve(bindable, |rocket| async move {
let rocket = Arc::new(rocket);
rocket.shutdown.spawn_listener(&rocket.config.shutdown);
if let Err(e) = tokio::spawn(Rocket::liftoff(rocket.clone())).await {
let rocket = rocket.try_wait_shutdown().await;
return Err(ErrorKind::Liftoff(rocket, Box::new(e)).into());
}
Ok(rocket)
}).await?;
Ok(rocket.try_wait_shutdown().await.map_err(ErrorKind::Shutdown)?)
}
}
impl Rocket<Orbit> {
/// Rocket wraps all connections in a `CancellableIo` struct, an internal
/// structure that gracefully closes I/O when it receives a signal. That
/// signal is the `shutdown` future. When the future resolves,
/// `CancellableIo` begins to terminate in grace, mercy, and finally force
/// close phases. Since all connections are wrapped in `CancellableIo`, this
/// eventually ends all I/O.
///
/// At that point, unless a user spawned an infinite, stand-alone task that
/// isn't monitoring `Shutdown`, all tasks should resolve. This means that
/// all instances of the shared `Arc<Rocket>` are dropped and we can return
/// the owned instance of `Rocket`.
///
/// Unfortunately, the Hyper `server` future resolves as soon as it has
/// finished processing requests without respect for ongoing responses. That
/// is, `server` resolves even when there are running tasks that are
/// generating a response. So, `server` resolving implies little to nothing
/// about the state of connections. As a result, we depend on the timing of
/// grace + mercy + some buffer to determine when all connections should be
/// closed, thus all tasks should be complete, thus all references to
/// `Arc<Rocket>` should be dropped and we can get back a unique reference.
async fn try_wait_shutdown(self: Arc<Self>) -> Result<Rocket<Ignite>, Arc<Self>> {
info!("Shutting down. Waiting for shutdown fairings and pending I/O...");
tokio::spawn({
let rocket = self.clone();
async move { rocket.fairings.handle_shutdown(&*rocket).await }
});
let config = &self.config.shutdown;
let wait = Duration::from_micros(250);
for period in [wait, config.grace(), wait, config.mercy(), wait * 4] {
if Arc::strong_count(&self) == 1 { break }
tokio::time::sleep(period).await;
}
match Arc::try_unwrap(self) {
Ok(rocket) => {
info!("Graceful shutdown completed successfully.");
Ok(rocket.into_ignite())
}
Err(rocket) => {
warn!("Shutdown failed: outstanding background I/O.");
Err(rocket)
}
}
}
pub(crate) fn into_ignite(self) -> Rocket<Ignite> {
Rocket(Igniting {
router: self.0.router,
@ -717,7 +779,7 @@ impl Rocket<Orbit> {
launch_info!("{}{} {}", "🚀 ".emoji(),
"Rocket has launched on".bold().primary().linger(),
rocket.endpoint().underline());
rocket.endpoints[0].underline());
}
/// Returns the finalized, active configuration. This is guaranteed to
@ -742,8 +804,8 @@ impl Rocket<Orbit> {
&self.config
}
pub fn endpoint(&self) -> &Endpoint {
&self.endpoint
pub fn endpoints(&self) -> impl Iterator<Item = &Endpoint> {
self.endpoints.iter()
}
/// Returns a handle which can be used to trigger a shutdown and detect a
@ -751,8 +813,8 @@ impl Rocket<Orbit> {
///
/// A completed graceful shutdown resolves the future returned by
/// [`Rocket::launch()`]. See [`Shutdown`] and
/// [`config::Shutdown`](crate::config::Shutdown) for details on graceful
/// shutdown.
/// [`ShutdownConfig`](crate::config::ShutdownConfig) for details on
/// graceful shutdown.
///
/// # Example
///
@ -774,7 +836,7 @@ impl Rocket<Orbit> {
/// }
/// ```
pub fn shutdown(&self) -> Shutdown {
self.shutdown.clone()
self.shutdown.start.clone()
}
}
@ -879,10 +941,10 @@ impl<P: Phase> Rocket<P> {
}
}
pub(crate) async fn local_launch(self, l: Endpoint) -> Result<Rocket<Orbit>, Error> {
pub(crate) async fn local_launch(self, e: Endpoint) -> Result<Rocket<Orbit>, Error> {
let rocket = match self.0.into_state() {
State::Build(s) => Rocket::from(s).ignite().await?._local_launch(l).await,
State::Ignite(s) => Rocket::from(s)._local_launch(l).await,
State::Build(s) => Rocket::from(s).ignite().await?._local_launch(e).await,
State::Ignite(s) => Rocket::from(s)._local_launch(e).await,
State::Orbit(s) => Rocket::from(s)
};
@ -941,7 +1003,9 @@ impl<P: Phase> Rocket<P> {
}
}
pub async fn launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error> {
pub async fn launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error>
where <B::Listener as Listener>::Connection: AsyncRead + AsyncWrite
{
match self.0.into_state() {
State::Build(s) => Rocket::from(s).ignite().await?._launch_on(bindable).await,
State::Ignite(s) => Rocket::from(s)._launch_on(bindable).await,

View File

@ -213,8 +213,8 @@ mod tests {
use std::str::FromStr;
use super::*;
use crate::route::{Route, dummy_handler};
use crate::http::{Method, Method::*, MediaType};
use crate::route::dummy_handler;
use crate::http::{Method, Method::*};
fn dummy_route(ranked: bool, method: impl Into<Option<Method>>, uri: &'static str) -> Route {
let method = method.into().unwrap_or(Get);

View File

@ -102,7 +102,7 @@ mod test {
use crate::route::dummy_handler;
use crate::local::blocking::Client;
use crate::http::{Method, Method::*, uri::Origin};
use crate::http::{Method::*, uri::Origin};
impl Router {
fn has_collisions(&self) -> bool {

View File

@ -6,33 +6,35 @@ use std::time::Duration;
use hyper::service::service_fn;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use hyper_util::server::conn::auto::Builder;
use futures::{Future, TryFutureExt, future::{select, Either::*}};
use tokio::time::sleep;
use futures::{Future, TryFutureExt, future::Either::*};
use tokio::io::{AsyncRead, AsyncWrite};
use crate::{Request, Rocket, Orbit, Data, Ignite};
use crate::{Ignite, Orbit, Request, Rocket};
use crate::request::ConnectionMeta;
use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler};
use crate::listener::{Listener, CancellableExt, BouncedExt};
use crate::error::{Error, ErrorKind};
use crate::data::IoStream;
use crate::util::ReaderStream;
use crate::listener::{Bindable, BouncedExt, CancellableExt, Listener};
use crate::error::{log_server_error, ErrorKind};
use crate::data::{IoStream, RawStream};
use crate::util::{spawn_inspect, FutureExt, ReaderStream};
use crate::http::Status;
type Result<T, E = crate::Error> = std::result::Result<T, E>;
impl Rocket<Orbit> {
async fn service(
async fn service<T: for<'a> Into<RawStream<'a>>>(
self: Arc<Self>,
mut req: hyper::Request<hyper::body::Incoming>,
parts: http::request::Parts,
stream: T,
upgrade: Option<hyper::upgrade::OnUpgrade>,
connection: ConnectionMeta,
) -> Result<hyper::Response<ReaderStream<ErasedResponse>>, http::Error> {
let upgrade = hyper::upgrade::on(&mut req);
let (parts, incoming) = req.into_parts();
let alt_svc = self.alt_svc();
let request = ErasedRequest::new(self, parts, |rocket, parts| {
Request::from_hyp(rocket, parts, connection).unwrap_or_else(|e| e)
});
let mut response = request.into_response(
incoming,
|incoming| Data::from(incoming),
stream,
|rocket, request, data| Box::pin(rocket.preprocess(request, data)),
|token, rocket, request, data| Box::pin(async move {
if !request.errors.is_empty() {
@ -46,7 +48,7 @@ impl Rocket<Orbit> {
).await;
let io_handler = response.to_io_handler(Rocket::extract_io_handler);
if let Some(handler) = io_handler {
if let (Some(handler), Some(upgrade)) = (io_handler, upgrade) {
let upgrade = upgrade.map_ok(IoStream::from).map_err(io::Error::other);
tokio::task::spawn(io_handler_task(upgrade, handler));
}
@ -58,12 +60,28 @@ impl Rocket<Orbit> {
}
if let Some(size) = response.inner().body().preset_size() {
builder = builder.header("Content-Length", size);
builder = builder.header(http::header::CONTENT_TYPE, size);
}
if let Some(alt_svc) = alt_svc {
let value = http::HeaderValue::from_static(alt_svc);
builder = builder.header(http::header::ALT_SVC, value);
}
let chunk_size = response.inner().body().max_chunk_size();
builder.body(ReaderStream::with_capacity(response, chunk_size))
}
fn alt_svc(&self) -> Option<&'static str> {
cfg!(feature = "http3-preview").then(|| {
static ALT_SVC: state::InitCell<Option<String>> = state::InitCell::new();
ALT_SVC.get_or_init(|| {
let addr = self.endpoints().find_map(|v| v.quic())?;
Some(format!("h3=\":{}\"", addr.port()))
}).as_deref()
})?
}
}
async fn io_handler_task<S>(stream: S, mut handler: ErasedIoHandler)
@ -84,8 +102,51 @@ async fn io_handler_task<S>(stream: S, mut handler: ErasedIoHandler)
}
impl Rocket<Ignite> {
pub(crate) async fn serve<L>(self, listener: L) -> Result<Self, crate::Error>
where L: Listener + 'static
pub(crate) async fn bind_and_serve<B, R>(
self,
bindable: B,
post_bind_callback: impl FnOnce(Rocket<Orbit>) -> R,
) -> Result<Arc<Rocket<Orbit>>>
where B: Bindable,
<B::Listener as Listener>::Connection: AsyncRead + AsyncWrite,
R: Future<Output = Result<Arc<Rocket<Orbit>>>>
{
let binding_endpoint = bindable.candidate_endpoint().ok();
let h12listener = bindable.bind()
.map_err(|e| ErrorKind::Bind(binding_endpoint, Box::new(e)))
.await?;
let endpoint = h12listener.endpoint()?;
#[cfg(feature = "http3-preview")]
if let (Some(addr), Some(tls)) = (endpoint.tcp(), endpoint.tls_config()) {
let h3listener = crate::listener::quic::QuicListener::bind(addr, tls.clone()).await?;
let rocket = self.into_orbit(vec![h3listener.endpoint()?, endpoint]);
let rocket = post_bind_callback(rocket).await?;
let http12 = tokio::task::spawn(rocket.clone().serve12(h12listener));
let http3 = tokio::task::spawn(rocket.clone().serve3(h3listener));
let (r1, r2) = tokio::join!(http12, http3);
r1.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??;
r2.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??;
return Ok(rocket);
}
if cfg!(feature = "http3-preview") {
warn!("HTTP/3 cannot start without a valid TCP + TLS configuration.");
info_!("Falling back to HTTP/1 + HTTP/2 server.");
}
let rocket = self.into_orbit(vec![endpoint]);
let rocket = post_bind_callback(rocket).await?;
rocket.clone().serve12(h12listener).await?;
Ok(rocket)
}
}
impl Rocket<Orbit> {
pub(crate) async fn serve12<L>(self: Arc<Self>, listener: L) -> Result<()>
where L: Listener + 'static,
L::Connection: AsyncRead + AsyncWrite
{
let mut builder = Builder::new(TokioExecutor::new());
let keep_alive = Duration::from_secs(self.config.keep_alive.into());
@ -106,75 +167,64 @@ impl Rocket<Ignite> {
}
}
let listener = listener.bounced().cancellable(self.shutdown(), &self.config.shutdown);
let rocket = Arc::new(self.into_orbit(listener.socket_addr()?));
let _ = tokio::spawn(Rocket::liftoff(rocket.clone())).await;
let (listener, server) = (Arc::new(listener.bounced()), Arc::new(builder));
while let Some(accept) = listener.accept().unless(self.shutdown()).await? {
let (listener, rocket, server) = (listener.clone(), self.clone(), server.clone());
spawn_inspect(|e| log_server_error(&**e), async move {
let conn = listener.connect(accept).io_unless(rocket.shutdown()).await?;
let meta = ConnectionMeta::from(&conn);
let service = service_fn(|mut req| {
let upgrade = hyper::upgrade::on(&mut req);
let (parts, incoming) = req.into_parts();
rocket.clone().service(parts, incoming, Some(upgrade), meta.clone())
});
let (server, listener) = (Arc::new(builder), Arc::new(listener));
while let Some(accept) = listener.accept_next().await {
let (listener, rocket, server) = (listener.clone(), rocket.clone(), server.clone());
tokio::spawn({
let result = async move {
let conn = TokioIo::new(listener.connect(accept).await?);
let meta = ConnectionMeta::from(conn.inner());
let service = service_fn(|req| rocket.clone().service(req, meta.clone()));
let serve = pin!(server.serve_connection_with_upgrades(conn, service));
match select(serve, rocket.shutdown()).await {
Left((result, _)) => result,
Right((_, mut conn)) => {
conn.as_mut().graceful_shutdown();
conn.await
}
}
};
result.inspect_err(crate::error::log_server_error)
let io = TokioIo::new(conn.cancellable(rocket.shutdown.clone()));
let mut server = pin!(server.serve_connection_with_upgrades(io, service));
match server.as_mut().or(rocket.shutdown()).await {
Left(result) => result,
Right(()) => {
server.as_mut().graceful_shutdown();
server.await
},
}
});
}
// Rocket wraps all connections in a `CancellableIo` struct, an internal
// structure that gracefully closes I/O when it receives a signal. That
// signal is the `shutdown` future. When the future resolves,
// `CancellableIo` begins to terminate in grace, mercy, and finally
// force close phases. Since all connections are wrapped in
// `CancellableIo`, this eventually ends all I/O.
//
// At that point, unless a user spawned an infinite, stand-alone task
// that isn't monitoring `Shutdown`, all tasks should resolve. This
// means that all instances of the shared `Arc<Rocket>` are dropped and
// we can return the owned instance of `Rocket`.
//
// Unfortunately, the Hyper `server` future resolves as soon as it has
// finished processing requests without respect for ongoing responses.
// That is, `server` resolves even when there are running tasks that are
// generating a response. So, `server` resolving implies little to
// nothing about the state of connections. As a result, we depend on the
// timing of grace + mercy + some buffer to determine when all
// connections should be closed, thus all tasks should be complete, thus
// all references to `Arc<Rocket>` should be dropped and we can get back
// a unique reference.
info!("Shutting down. Waiting for shutdown fairings and pending I/O...");
tokio::spawn({
let rocket = rocket.clone();
async move { rocket.fairings.handle_shutdown(&*rocket).await }
});
Ok(())
}
let config = &rocket.config.shutdown;
let wait = Duration::from_micros(250);
for period in [wait, config.grace(), wait, config.mercy(), wait * 4] {
if Arc::strong_count(&rocket) == 1 { break }
sleep(period).await;
#[cfg(feature = "http3-preview")]
async fn serve3(self: Arc<Self>, listener: crate::listener::quic::QuicListener) -> Result<()> {
let rocket = self.clone();
let listener = Arc::new(listener.bounced());
while let Some(accept) = listener.accept().unless(rocket.shutdown()).await? {
let (listener, rocket) = (listener.clone(), rocket.clone());
spawn_inspect(|e: &io::Error| log_server_error(e), async move {
let mut stream = listener.connect(accept).io_unless(rocket.shutdown()).await?;
while let Some(mut conn) = stream.accept().io_unless(rocket.shutdown()).await? {
let rocket = rocket.clone();
spawn_inspect(|e: &io::Error| log_server_error(e), async move {
let meta = ConnectionMeta::from(&conn);
let rx = conn.rx.cancellable(rocket.shutdown.clone());
let response = rocket.clone()
.service(conn.parts, rx, None, ConnectionMeta::from(meta))
.map_err(io::Error::other)
.io_unless(rocket.shutdown.mercy.clone())
.await?;
let grace = rocket.shutdown.grace.clone();
match conn.tx.send_response(response).or(grace).await {
Left(result) => result,
Right(_) => Ok(conn.tx.cancel()),
}
});
}
Ok(())
});
}
match Arc::try_unwrap(rocket) {
Ok(rocket) => {
info!("Graceful shutdown completed successfully.");
Ok(rocket.into_ignite())
}
Err(rocket) => {
warn!("Shutdown failed: outstanding background I/O.");
Err(Error::new(ErrorKind::Shutdown(rocket)))
}
}
Ok(())
}
}

View File

@ -1,7 +1,6 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use state::InitCell;
use yansi::Paint;
use crate::{Rocket, Request, Response, Orbit, Config};
@ -68,11 +67,18 @@ use crate::shield::*;
/// policy.
pub struct Shield {
/// Enabled policies where the key is the header name.
policies: HashMap<&'static UncasedStr, Box<dyn SubPolicy>>,
policies: HashMap<&'static UncasedStr, Header<'static>>,
/// Whether to enforce HSTS even though the user didn't enable it.
force_hsts: AtomicBool,
/// Headers pre-rendered at liftoff from the configured policies.
rendered: InitCell<Vec<Header<'static>>>,
}
impl Clone for Shield {
fn clone(&self) -> Self {
Self {
policies: self.policies.clone(),
force_hsts: AtomicBool::from(self.force_hsts.load(Ordering::Acquire)),
}
}
}
impl Default for Shield {
@ -111,7 +117,6 @@ impl Shield {
Shield {
policies: HashMap::new(),
force_hsts: AtomicBool::new(false),
rendered: InitCell::new(),
}
}
@ -129,8 +134,7 @@ impl Shield {
/// let shield = Shield::new().enable(NoSniff::default());
/// ```
pub fn enable<P: Policy>(mut self, policy: P) -> Self {
self.rendered = InitCell::new();
self.policies.insert(P::NAME.into(), Box::new(policy));
self.policies.insert(P::NAME.into(), policy.header());
self
}
@ -145,7 +149,6 @@ impl Shield {
/// let shield = Shield::default().disable::<NoSniff>();
/// ```
pub fn disable<P: Policy>(mut self) -> Self {
self.rendered = InitCell::new();
self.policies.remove(UncasedStr::new(P::NAME));
self
}
@ -172,20 +175,6 @@ impl Shield {
pub fn is_enabled<P: Policy>(&self) -> bool {
self.policies.contains_key(UncasedStr::new(P::NAME))
}
fn headers(&self) -> &[Header<'static>] {
self.rendered.get_or_init(|| {
let mut headers: Vec<_> = self.policies.values()
.map(|p| p.header())
.collect();
if self.force_hsts.load(Ordering::Acquire) {
headers.push(Policy::header(&Hsts::default()));
}
headers
})
}
}
#[crate::async_trait]
@ -198,7 +187,7 @@ impl Fairing for Shield {
}
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
let force_hsts = rocket.endpoint().is_tls()
let force_hsts = rocket.endpoints().all(|v| v.is_tls())
&& rocket.figment().profile() != Config::DEBUG_PROFILE
&& !self.is_enabled::<Hsts>();
@ -206,10 +195,10 @@ impl Fairing for Shield {
self.force_hsts.store(true, Ordering::Release);
}
if !self.headers().is_empty() {
if !self.policies.is_empty() {
info!("{}{}:", "🛡️ ".emoji(), "Shield".magenta());
for header in self.headers() {
for header in self.policies.values() {
info_!("{}: {}", header.name(), header.value().primary());
}
@ -224,7 +213,7 @@ impl Fairing for Shield {
async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) {
// Set all of the headers in `self.policies` in `response` as long as
// the header is not already in the response.
for header in self.headers() {
for header in self.policies.values() {
if response.headers().contains(header.name()) {
warn!("Shield: response contains a '{}' header.", header.name());
warn_!("Refusing to overwrite existing header.");

View File

@ -6,60 +6,7 @@ use std::collections::HashSet;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
/// A Unix signal for triggering graceful shutdown.
///
/// Each variant corresponds to a Unix process signal which can be used to
/// trigger a graceful shutdown. See [`Shutdown`] for details.
///
/// ## (De)serialization
///
/// A `Sig` variant serializes and deserializes as a lowercase string equal to
/// the name of the variant: `"alrm"` for [`Sig::Alrm`], `"chld"` for
/// [`Sig::Chld`], and so on.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[cfg_attr(nightly, doc(cfg(unix)))]
pub enum Sig {
/// The `SIGALRM` Unix signal.
Alrm,
/// The `SIGCHLD` Unix signal.
Chld,
/// The `SIGHUP` Unix signal.
Hup,
/// The `SIGINT` Unix signal.
Int,
/// The `SIGIO` Unix signal.
Io,
/// The `SIGPIPE` Unix signal.
Pipe,
/// The `SIGQUIT` Unix signal.
Quit,
/// The `SIGTERM` Unix signal.
Term,
/// The `SIGUSR1` Unix signal.
Usr1,
/// The `SIGUSR2` Unix signal.
Usr2
}
impl fmt::Display for Sig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Sig::Alrm => "SIGALRM",
Sig::Chld => "SIGCHLD",
Sig::Hup => "SIGHUP",
Sig::Int => "SIGINT",
Sig::Io => "SIGIO",
Sig::Pipe => "SIGPIPE",
Sig::Quit => "SIGQUIT",
Sig::Term => "SIGTERM",
Sig::Usr1 => "SIGUSR1",
Sig::Usr2 => "SIGUSR2",
};
s.fmt(f)
}
}
use crate::shutdown::Sig;
/// Graceful shutdown configuration.
///
@ -94,11 +41,13 @@ impl fmt::Display for Sig {
///
/// Once a shutdown is triggered, Rocket stops accepting new connections and
/// waits at most `grace` seconds before initiating connection shutdown.
/// Applications can `await` the [`Shutdown`](crate::Shutdown) future to detect
/// Applications can `await` the [`Shutdown`] future to detect
/// a shutdown and cancel any server-initiated I/O, such as from [infinite
/// responders](crate::response::stream#graceful-shutdown), to avoid abrupt I/O
/// cancellation.
///
/// [`Shutdown`]: crate::Shutdown
///
/// # Mercy Period
///
/// After the grace period has elapsed, Rocket initiates connection shutdown,
@ -125,7 +74,8 @@ impl fmt::Display for Sig {
/// prevent _buggy_ code, such as an unintended infinite loop or unknown use of
/// blocking I/O, from preventing shutdown.
///
/// This behavior can be disabled by setting [`Shutdown::force`] to `false`.
/// This behavior can be disabled by setting [`ShutdownConfig::force`] to
/// `false`.
///
/// # Example
///
@ -169,13 +119,13 @@ impl fmt::Display for Sig {
///
/// ```rust
/// # use rocket::figment::{Figment, providers::{Format, Toml}};
/// use rocket::config::{Config, Shutdown};
/// use rocket::config::{Config, ShutdownConfig};
///
/// #[cfg(unix)]
/// use rocket::config::Sig;
///
/// let config = Config {
/// shutdown: Shutdown {
/// shutdown: ShutdownConfig {
/// ctrlc: false,
/// #[cfg(unix)]
/// signals: {
@ -204,7 +154,7 @@ impl fmt::Display for Sig {
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Shutdown {
pub struct ShutdownConfig {
/// Whether `ctrl-c` (`SIGINT`) initiates a server shutdown.
///
/// **default: `true`**
@ -245,9 +195,9 @@ pub struct Shutdown {
/// _always_ be done using a public constructor or update syntax:
///
/// ```rust
/// use rocket::config::Shutdown;
/// use rocket::config::ShutdownConfig;
///
/// let config = Shutdown {
/// let config = ShutdownConfig {
/// grace: 5,
/// mercy: 10,
/// ..Default::default()
@ -258,7 +208,7 @@ pub struct Shutdown {
pub __non_exhaustive: (),
}
impl fmt::Display for Shutdown {
impl fmt::Display for ShutdownConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ctrlc = {}, force = {}, ", self.ctrlc, self.force)?;
@ -276,9 +226,9 @@ impl fmt::Display for Shutdown {
}
}
impl Default for Shutdown {
impl Default for ShutdownConfig {
fn default() -> Self {
Shutdown {
ShutdownConfig {
ctrlc: true,
#[cfg(unix)]
signals: { let mut set = HashSet::new(); set.insert(Sig::Term); set },
@ -290,7 +240,7 @@ impl Default for Shutdown {
}
}
impl Shutdown {
impl ShutdownConfig {
pub(crate) fn grace(&self) -> Duration {
Duration::from_secs(self.grace as u64)
}

View File

@ -2,22 +2,22 @@ use std::future::Future;
use std::task::{Context, Poll};
use std::pin::Pin;
use futures::FutureExt;
use futures::{FutureExt, StreamExt};
use crate::shutdown::{ShutdownConfig, TripWire};
use crate::request::{FromRequest, Outcome, Request};
use crate::util::TripWire;
/// A request guard and future for graceful shutdown.
///
/// A server shutdown is manually requested by calling [`Shutdown::notify()`]
/// or, if enabled, through [automatic triggers] like `Ctrl-C`. Rocket will stop accepting new
/// requests, finish handling any pending requests, wait a grace period before
/// cancelling any outstanding I/O, and return `Ok()` to the caller of
/// [`Rocket::launch()`]. Graceful shutdown is configured via
/// [`config::Shutdown`](crate::config::Shutdown).
/// or, if enabled, through [automatic triggers] like `Ctrl-C`. Rocket will stop
/// accepting new requests, finish handling any pending requests, wait a grace
/// period before cancelling any outstanding I/O, and return `Ok()` to the
/// caller of [`Rocket::launch()`]. Graceful shutdown is configured via
/// [`ShutdownConfig`](crate::config::ShutdownConfig).
///
/// [`Rocket::launch()`]: crate::Rocket::launch()
/// [automatic triggers]: crate::config::Shutdown#triggers
/// [automatic triggers]: crate::shutdown::Shutdown#triggers
///
/// # Detecting Shutdown
///
@ -65,15 +65,30 @@ use crate::util::TripWire;
/// ```
#[derive(Debug, Clone)]
#[must_use = "`Shutdown` does nothing unless polled or `notify`ed"]
pub struct Shutdown(pub(crate) TripWire);
pub struct Shutdown {
wire: TripWire,
}
#[derive(Debug, Clone)]
pub struct Stages {
pub start: Shutdown,
pub grace: Shutdown,
pub mercy: Shutdown,
}
impl Shutdown {
fn new() -> Self {
Shutdown {
wire: TripWire::new(),
}
}
/// Notify the application to shut down gracefully.
///
/// This function returns immediately; pending requests will continue to run
/// until completion or expiration of the grace period, which ever comes
/// first, before the actual shutdown occurs. The grace period can be
/// configured via [`Shutdown::grace`](crate::config::Shutdown::grace).
/// configured via [`Shutdown::grace`](crate::config::ShutdownConfig::grace).
///
/// ```rust
/// # use rocket::*;
@ -85,9 +100,37 @@ impl Shutdown {
/// "Shutting down..."
/// }
/// ```
#[inline]
pub fn notify(self) {
self.0.trip();
#[inline(always)]
pub fn notify(&self) {
self.wire.trip();
}
/// Returns `true` if `Shutdown::notify()` has already been called.
///
/// # Example
///
/// ```rust
/// # use rocket::*;
/// use rocket::Shutdown;
///
/// #[get("/shutdown")]
/// fn shutdown(shutdown: Shutdown) {
/// shutdown.notify();
/// assert!(shutdown.notified());
/// }
/// ```
#[must_use]
#[inline(always)]
pub fn notified(&self) -> bool {
self.wire.tripped()
}
}
impl Future for Shutdown {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.wire.poll_unpin(cx)
}
}
@ -101,11 +144,41 @@ impl<'r> FromRequest<'r> for Shutdown {
}
}
impl Future for Shutdown {
type Output = ();
impl Stages {
pub fn new() -> Self {
Stages {
start: Shutdown::new(),
grace: Shutdown::new(),
mercy: Shutdown::new(),
}
}
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_unpin(cx)
pub(crate) fn spawn_listener(&self, config: &ShutdownConfig) {
use futures::stream;
use futures::future::{select, Either};
let mut signal = match config.signal_stream() {
Some(stream) => Either::Left(stream.chain(stream::pending())),
None => Either::Right(stream::pending()),
};
let start = self.start.clone();
let (grace, grace_duration) = (self.grace.clone(), config.grace());
let (mercy, mercy_duration) = (self.mercy.clone(), config.mercy());
tokio::spawn(async move {
if let Either::Left((sig, start)) = select(signal.next(), start).await {
warn!("Received {}. Shutdown started.", sig.unwrap());
start.notify();
}
tokio::time::sleep(grace_duration).await;
warn!("Shutdown grace period elapsed. Shutting down I/O.");
grace.notify();
tokio::time::sleep(mercy_duration).await;
warn!("Mercy period elapsed. Terminating I/O.");
mercy.notify();
});
}
}

View File

@ -0,0 +1,13 @@
//! Shutdown configuration and notification handle.
mod tripwire;
mod handle;
mod sig;
mod config;
pub(crate) use tripwire::TripWire;
pub(crate) use handle::Stages;
pub use config::ShutdownConfig;
pub use handle::Shutdown;
pub use sig::Sig;

View File

@ -0,0 +1,58 @@
use std::fmt;
use serde::{Deserialize, Serialize};
/// A Unix signal for triggering graceful shutdown.
///
/// Each variant corresponds to a Unix process signal which can be used to
/// trigger a graceful shutdown. See [`Shutdown`](crate::Shutdown) for details.
///
/// ## (De)serialization
///
/// A `Sig` variant serializes and deserializes as a lowercase string equal to
/// the name of the variant: `"alrm"` for [`Sig::Alrm`], `"chld"` for
/// [`Sig::Chld`], and so on.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[cfg_attr(nightly, doc(cfg(unix)))]
pub enum Sig {
/// The `SIGALRM` Unix signal.
Alrm,
/// The `SIGCHLD` Unix signal.
Chld,
/// The `SIGHUP` Unix signal.
Hup,
/// The `SIGINT` Unix signal.
Int,
/// The `SIGIO` Unix signal.
Io,
/// The `SIGPIPE` Unix signal.
Pipe,
/// The `SIGQUIT` Unix signal.
Quit,
/// The `SIGTERM` Unix signal.
Term,
/// The `SIGUSR1` Unix signal.
Usr1,
/// The `SIGUSR2` Unix signal.
Usr2
}
impl fmt::Display for Sig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Sig::Alrm => "SIGALRM",
Sig::Chld => "SIGCHLD",
Sig::Hup => "SIGHUP",
Sig::Int => "SIGINT",
Sig::Io => "SIGIO",
Sig::Pipe => "SIGPIPE",
Sig::Quit => "SIGQUIT",
Sig::Term => "SIGTERM",
Sig::Usr1 => "SIGUSR1",
Sig::Usr2 => "SIGUSR2",
};
s.fmt(f)
}
}

View File

@ -3,6 +3,8 @@ use std::{ops::Deref, pin::Pin, future::Future};
use std::task::{Context, Poll};
use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
use futures::future::FusedFuture;
use tokio::sync::futures::Notified;
use tokio::sync::Notify;
#[doc(hidden)]
@ -15,7 +17,7 @@ pub struct State {
pub struct TripWire {
state: Arc<State>,
// `Notified` is `!Unpin`. Even if we could name it, we'd need to pin it.
event: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
event: Option<Pin<Box<Notified<'static>>>>,
}
impl Deref for TripWire {
@ -35,6 +37,13 @@ impl Clone for TripWire {
}
}
impl Drop for TripWire {
fn drop(&mut self) {
// SAFETY: Ensure we drop the self-reference before `self`.
self.event = None;
}
}
impl fmt::Debug for TripWire {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TripWire")
@ -47,35 +56,22 @@ impl Future for TripWire {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.tripped.load(Ordering::Acquire) {
if self.tripped() {
self.event = None;
return Poll::Ready(());
}
if self.event.is_none() {
let state = self.state.clone();
self.event = Some(Box::pin(async move {
let notified = state.notify.notified();
notified.await
}));
let notified = self.state.notify.notified();
// SAFETY: This is a self reference to the `state`.
self.event = Some(Box::pin(unsafe { std::mem::transmute(notified) }));
}
if let Some(ref mut event) = self.event {
if event.as_mut().poll(cx).is_ready() {
// We need to call `trip()` to avoid a race condition where:
// 1) many trip wires have seen !self.tripped but have not
// polled for `self.event` yet, so are not subscribed
// 2) trip() is called, adding a permit to `event`
// 3) some trip wires poll `event` for the first time
// 4) one of those wins, returns `Ready()`
// 5) the rest return pending
//
// Without this `self.trip()` those will never be awoken. With
// the call to self.trip(), those that made it to poll() in 3)
// will be awoken by `notify_waiters()`. For those the didn't,
// one will be awoken by `notify_one()`, which will in-turn call
// self.trip(), awaking more until there are no more to awake.
self.trip();
// The order here is important! We need to know:
// !self.tripped() => not notified == notified => self.tripped()
if event.as_mut().poll(cx).is_ready() || self.tripped() {
self.event = None;
return Poll::Ready(());
}
@ -85,6 +81,12 @@ impl Future for TripWire {
}
}
impl FusedFuture for TripWire {
fn is_terminated(&self) -> bool {
self.tripped()
}
}
impl TripWire {
pub fn new() -> Self {
TripWire {
@ -99,7 +101,6 @@ impl TripWire {
pub fn trip(&self) {
self.tripped.store(true, Ordering::Release);
self.notify.notify_waiters();
self.notify.notify_one();
}
#[inline(always)]

View File

@ -427,7 +427,7 @@ impl TlsConfig {
}
pub fn validate(&self) -> Result<(), crate::tls::Error> {
self.acceptor().map(|_| ())
self.server_config().map(|_| ())
}
}

View File

@ -1,5 +1,4 @@
mod chain;
mod tripwire;
mod reader_stream;
mod join;
@ -7,6 +6,55 @@ mod join;
pub mod unix;
pub use chain::Chain;
pub use tripwire::TripWire;
pub use reader_stream::ReaderStream;
pub use join::join;
#[track_caller]
pub fn spawn_inspect<E, F, Fut>(or: F, future: Fut)
where F: FnOnce(&E) + Send + Sync + 'static,
E: Send + Sync + 'static,
Fut: std::future::Future<Output = Result<(), E>> + Send + 'static,
{
use futures::TryFutureExt;
tokio::spawn(future.inspect_err(or));
}
use std::io;
use std::pin::pin;
use std::future::Future;
use futures::future::{select, Either};
pub trait FutureExt: Future + Sized {
/// Await `self` or `other`, whichever finishes first.
async fn or<B: Future>(self, other: B) -> Either<Self::Output, B::Output> {
match futures::future::select(pin!(self), pin!(other)).await {
Either::Left((v, _)) => Either::Left(v),
Either::Right((v, _)) => Either::Right(v),
}
}
/// Await `self` unless `trigger` completes. Returns `Ok(Some(T))` if `self`
/// completes successfully before `trigger`, `Err(E)` if `self` completes
/// unsuccessfully, and `Ok(None)` if `trigger` completes before `self`.
async fn unless<T, E, K: Future>(self, trigger: K) -> Result<Option<T>, E>
where Self: Future<Output = Result<T, E>>
{
match select(pin!(self), pin!(trigger)).await {
Either::Left((v, _)) => Ok(Some(v?)),
Either::Right((_, _)) => Ok(None),
}
}
/// Await `self` unless `trigger` completes. If `self` completes before
/// `trigger`, returns the result. Otherwise, always returns an `Err`.
async fn io_unless<T, K: Future>(self, trigger: K) -> std::io::Result<T>
where Self: Future<Output = std::io::Result<T>>
{
match select(pin!(self), pin!(trigger)).await {
Either::Left((v, _)) => v,
Either::Right((_, _)) => Err(io::Error::other("I/O terminated")),
}
}
}
impl<F: Future + Sized> FutureExt for F { }

View File

@ -10,7 +10,8 @@ async fn on_ignite_fairing_can_inspect_port() {
let rocket = rocket::custom(Config::debug_default())
.attach(AdHoc::on_liftoff("Send Port -> Channel", move |rocket| {
Box::pin(async move {
tx.send(rocket.endpoint().tcp().unwrap().port()).unwrap();
let tcp = rocket.endpoints().find_map(|v| v.tcp());
tx.send(tcp.unwrap().port()).expect("send okay");
})
}));

View File

@ -21,24 +21,24 @@ is configured with. This means that no matter which configuration provider
Rocket is asked to use, it must be able to read the following configuration
values:
| key | kind | description | debug/release default |
|----------------------|-------------------|-------------------------------------------------|-------------------------|
| `address` | `IpAddr` | IP address to serve on. | `127.0.0.1` |
| `port` | `u16` | Port to serve on. | `8000` |
| `workers`* | `usize` | Number of threads to use for executing futures. | cpu core count |
| `max_blocking`* | `usize` | Limit on threads to start for blocking tasks. | `512` |
| `ident` | `string`, `false` | If and how to identify via the `Server` header. | `"Rocket"` |
| `ip_header` | `string`, `false` | IP header to inspect to get [client's real IP]. | `"X-Real-IP"` |
| `proxy_proto_header` | `string`, `false` | Header identifying [client to proxy protocol]. | `None` |
| `keep_alive` | `u32` | Keep-alive timeout seconds; disabled when `0`. | `5` |
| `log_level` | [`LogLevel`] | Max level to log. (off/normal/debug/critical) | `normal`/`critical` |
| `cli_colors` | [`CliColors`] | Whether to use colors and emoji when logging. | `"auto"` |
| `secret_key` | [`SecretKey`] | Secret key for signing and encrypting values. | `None` |
| `tls` | [`TlsConfig`] | TLS configuration, if any. | `None` |
| `limits` | [`Limits`] | Streaming read size limits. | [`Limits::default()`] |
| `limits.$name` | `&str`/`uint` | Read limit for `$name`. | form = "32KiB" |
| `ctrlc` | `bool` | Whether `ctrl-c` initiates a server shutdown. | `true` |
| `shutdown`* | [`Shutdown`] | Graceful shutdown configuration. | [`Shutdown::default()`] |
| key | kind | description | debug/release default |
|----------------------|--------------------|-------------------------------------------------|-------------------------------|
| `address` | `IpAddr` | IP address to serve on. | `127.0.0.1` |
| `port` | `u16` | Port to serve on. | `8000` |
| `workers`* | `usize` | Number of threads to use for executing futures. | cpu core count |
| `max_blocking`* | `usize` | Limit on threads to start for blocking tasks. | `512` |
| `ident` | `string`, `false` | If and how to identify via the `Server` header. | `"Rocket"` |
| `ip_header` | `string`, `false` | IP header to inspect to get [client's real IP]. | `"X-Real-IP"` |
| `proxy_proto_header` | `string`, `false` | Header identifying [client to proxy protocol]. | `None` |
| `keep_alive` | `u32` | Keep-alive timeout seconds; disabled when `0`. | `5` |
| `log_level` | [`LogLevel`] | Max level to log. (off/normal/debug/critical) | `normal`/`critical` |
| `cli_colors` | [`CliColors`] | Whether to use colors and emoji when logging. | `"auto"` |
| `secret_key` | [`SecretKey`] | Secret key for signing and encrypting values. | `None` |
| `tls` | [`TlsConfig`] | TLS configuration, if any. | `None` |
| `limits` | [`Limits`] | Streaming read size limits. | [`Limits::default()`] |
| `limits.$name` | `&str`/`uint` | Read limit for `$name`. | form = "32KiB" |
| `ctrlc` | `bool` | Whether `ctrl-c` initiates a server shutdown. | `true` |
| `shutdown`* | [`ShutdownConfig`] | Graceful shutdown configuration. | [`ShutdownConfig::default()`] |
<small>* Note: the `workers`, `max_blocking`, and `shutdown.force` configuration
@ -77,8 +77,8 @@ profile supplant any values with the same name in any profile.
[`SecretKey`]: @api/master/rocket/config/struct.SecretKey.html
[`CliColors`]: @api/master/rocket/config/enum.CliColors.html
[`TlsConfig`]: @api/master/rocket/tls/struct.TlsConfig.html
[`Shutdown`]: @api/master/rocket/config/struct.Shutdown.html
[`Shutdown::default()`]: @api/master/rocket/config/struct.Shutdown.html#fields
[`ShutdownConfig`]: @api/master/rocket/shutdown/struct.ShutdownConfig.html
[`ShutdownConfig::default()`]: @api/master/rocket/shutdown/struct.ShutdownConfig.html#fields
## Default Provider

View File

@ -6,5 +6,5 @@ edition = "2021"
publish = false
[dependencies]
rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets"] }
rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets", "http3-preview"] }
yansi = "1.0.1"

View File

@ -1,10 +1,12 @@
#[macro_use] extern crate rocket;
#[macro_use]
extern crate rocket;
#[cfg(test)]
mod tests;
mod redirector;
use rocket::mtls::Certificate;
use rocket::listener::Endpoint;
#[get("/")]
fn mutual(cert: Certificate<'_>) -> String {
@ -12,8 +14,11 @@ fn mutual(cert: Certificate<'_>) -> String {
}
#[get("/", rank = 2)]
fn hello() -> &'static str {
"Hello, world!"
fn hello(endpoint: Option<&Endpoint>) -> String {
match endpoint {
Some(endpoint) => format!("Hello, {endpoint}!"),
None => "Hello, world!".into(),
}
}
#[launch]

View File

@ -74,7 +74,7 @@ impl Fairing for Redirector {
}
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
let Some(tls_addr) = rocket.endpoint().tls().and_then(|tls| tls.tcp()) else {
let Some(tls_addr) = rocket.endpoints().find_map(|e| e.tls()?.tcp()) else {
info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta());
warn_!("Main instance is not being served over TLS/TCP.");
warn_!("Redirector refusing to start.");

View File

@ -128,6 +128,7 @@ function test_core() {
FEATURES=(
tokio-macros
http2
http3-preview
secrets
tls
mtls