diff --git a/benchmarks/src/routing.rs b/benchmarks/src/routing.rs index 89dd4e5c..417615cd 100644 --- a/benchmarks/src/routing.rs +++ b/benchmarks/src/routing.rs @@ -82,7 +82,7 @@ fn client(routes: Vec) -> Client { profile: Config::RELEASE_PROFILE, log_level: rocket::config::LogLevel::Off, cli_colors: config::CliColors::Never, - shutdown: config::Shutdown { + shutdown: config::ShutdownConfig { ctrlc: false, #[cfg(unix)] signals: HashSet::new(), diff --git a/contrib/sync_db_pools/lib/tests/shutdown.rs b/contrib/sync_db_pools/lib/tests/shutdown.rs index 8f76b0fa..d338862e 100644 --- a/contrib/sync_db_pools/lib/tests/shutdown.rs +++ b/contrib/sync_db_pools/lib/tests/shutdown.rs @@ -12,6 +12,7 @@ mod sqlite_shutdown_test { let options = map!["url" => ":memory:"]; let config = Figment::from(rocket::Config::debug_default()) + .merge(("port", 0)) .merge(("databases", map!["test" => &options])); rocket::custom(config).attach(Pool::fairing()) diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index cc38bb1b..2c433d66 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -22,6 +22,7 @@ all-features = true [features] default = ["http2", "tokio-macros"] http2 = ["hyper/http2", "hyper-util/http2"] +http3-preview = ["s2n-quic", "s2n-quic-h3", "tls"] secrets = ["cookie/private", "cookie/key-expansion"] json = ["serde_json"] msgpack = ["rmp-serde"] @@ -76,8 +77,7 @@ futures = { version = "0.3.30", default-features = false, features = ["std"] } state = "0.6" [dependencies.hyper-util] -git = "https://github.com/SergioBenitez/hyper-util.git" -branch = "fix-readversion" +version = "0.1.3" default-features = false features = ["http1", "server", "tokio"] @@ -99,6 +99,16 @@ version = "0.6.0-dev" path = "../http" features = ["serde"] +[dependencies.s2n-quic] +version = "1.32" +default-features = false +features = ["provider-address-token-default", "provider-tls-rustls"] +optional = true + +[dependencies.s2n-quic-h3] +git = "https://github.com/SergioBenitez/s2n-quic-h3.git" +optional = true + [target.'cfg(unix)'.dependencies] libc = "0.2.149" diff --git a/core/lib/src/config/config.rs b/core/lib/src/config/config.rs index e208944c..ea8c8611 100644 --- a/core/lib/src/config/config.rs +++ b/core/lib/src/config/config.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use yansi::{Paint, Style, Color::Primary}; use crate::log::PaintExt; -use crate::config::{LogLevel, Shutdown, Ident, CliColors}; +use crate::config::{LogLevel, ShutdownConfig, Ident, CliColors}; use crate::request::{self, Request, FromRequest}; use crate::http::uncased::Uncased; use crate::data::Limits; @@ -120,8 +120,8 @@ pub struct Config { #[cfg_attr(nightly, doc(cfg(feature = "secrets")))] #[serde(serialize_with = "SecretKey::serialize_zero")] pub secret_key: SecretKey, - /// Graceful shutdown configuration. **(default: [`Shutdown::default()`])** - pub shutdown: Shutdown, + /// Graceful shutdown configuration. **(default: [`ShutdownConfig::default()`])** + pub shutdown: ShutdownConfig, /// Max level to log. **(default: _debug_ `normal` / _release_ `critical`)** pub log_level: LogLevel, /// Whether to use colors and emoji when logging. **(default: @@ -200,7 +200,7 @@ impl Config { keep_alive: 5, #[cfg(feature = "secrets")] secret_key: SecretKey::zero(), - shutdown: Shutdown::default(), + shutdown: ShutdownConfig::default(), log_level: LogLevel::Normal, cli_colors: CliColors::Auto, __non_exhaustive: (), @@ -408,9 +408,10 @@ impl Config { #[cfg(feature = "secrets")] { launch_meta_!("secret key: {}", self.secret_key.paint(VAL)); if !self.secret_key.is_provided() { - warn!("secrets enabled without a stable `secret_key`"); - launch_meta_!("disable `secrets` feature or configure a `secret_key`"); - launch_meta_!("this becomes an {} in non-debug profiles", "error".red()); + warn!("secrets enabled without configuring a stable `secret_key`"); + warn_!("private/signed cookies will become unreadable after restarting"); + launch_meta_!("disable the `secrets` feature or configure a `secret_key`"); + launch_meta_!("this becomes a {} in non-debug profiles", "hard error".red()); } } } diff --git a/core/lib/src/config/mod.rs b/core/lib/src/config/mod.rs index 86481af1..9f07e919 100644 --- a/core/lib/src/config/mod.rs +++ b/core/lib/src/config/mod.rs @@ -113,422 +113,35 @@ #[macro_use] mod ident; mod config; -mod shutdown; mod cli_colors; mod http_header; +#[cfg(test)] +mod tests; + +pub use ident::Ident; +pub use config::Config; +pub use cli_colors::CliColors; + +pub use crate::log::LogLevel; +pub use crate::shutdown::ShutdownConfig; + +#[cfg(feature = "tls")] +pub use crate::tls::TlsConfig; + +#[cfg(feature = "mtls")] +pub use crate::mtls::MtlsConfig; #[cfg(feature = "secrets")] mod secret_key; -#[doc(hidden)] -pub use config::{pretty_print_error, bail_with_config_error}; +#[cfg(unix)] +pub use crate::shutdown::Sig; -pub use config::Config; -pub use crate::log::LogLevel; -pub use shutdown::Shutdown; -pub use ident::Ident; -pub use cli_colors::CliColors; +#[cfg(unix)] +pub use crate::listener::unix::UdsConfig; #[cfg(feature = "secrets")] pub use secret_key::SecretKey; -#[cfg(unix)] -pub use shutdown::Sig; - -#[cfg(test)] -mod tests { - use figment::{Figment, Profile}; - use pretty_assertions::assert_eq; - - use crate::log::LogLevel; - use crate::data::{Limits, ToByteUnit}; - use crate::config::{Config, CliColors}; - - #[test] - fn test_figment_is_default() { - figment::Jail::expect_with(|_| { - let mut default: Config = Config::figment().extract().unwrap(); - default.profile = Config::default().profile; - assert_eq!(default, Config::default()); - Ok(()) - }); - } - - #[test] - fn test_default_round_trip() { - figment::Jail::expect_with(|_| { - let original = Config::figment(); - let roundtrip = Figment::from(Config::from(&original)); - for figment in &[original, roundtrip] { - let config = Config::from(figment); - assert_eq!(config, Config::default()); - } - - Ok(()) - }); - } - - #[test] - fn test_profile_env() { - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PROFILE", "debug"); - let figment = Config::figment(); - assert_eq!(figment.profile(), "debug"); - - jail.set_env("ROCKET_PROFILE", "release"); - let figment = Config::figment(); - assert_eq!(figment.profile(), "release"); - - jail.set_env("ROCKET_PROFILE", "random"); - let figment = Config::figment(); - assert_eq!(figment.profile(), "random"); - - Ok(()) - }); - } - - #[test] - fn test_toml_file() { - figment::Jail::expect_with(|jail| { - jail.create_file("Rocket.toml", r#" - [default] - ident = "Something Cool" - workers = 20 - keep_alive = 10 - log_level = "off" - cli_colors = 0 - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - workers: 20, - ident: ident!("Something Cool"), - keep_alive: 10, - log_level: LogLevel::Off, - cli_colors: CliColors::Never, - ..Config::default() - }); - - jail.create_file("Rocket.toml", r#" - [global] - ident = "Something Else Cool" - workers = 20 - keep_alive = 10 - log_level = "off" - cli_colors = 0 - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - workers: 20, - ident: ident!("Something Else Cool"), - keep_alive: 10, - log_level: LogLevel::Off, - cli_colors: CliColors::Never, - ..Config::default() - }); - - jail.set_env("ROCKET_CONFIG", "Other.toml"); - jail.create_file("Other.toml", r#" - [default] - workers = 20 - keep_alive = 10 - log_level = "off" - cli_colors = 0 - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - workers: 20, - keep_alive: 10, - log_level: LogLevel::Off, - cli_colors: CliColors::Never, - ..Config::default() - }); - - Ok(()) - }); - } - - #[test] - fn test_cli_colors() { - figment::Jail::expect_with(|jail| { - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = "never" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = "auto" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = "always" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Always); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = true - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = false - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.create_file("Rocket.toml", r#"[default]"#)?; - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = 1 - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = 0 - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.set_env("ROCKET_CLI_COLORS", 1); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.set_env("ROCKET_CLI_COLORS", 0); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.set_env("ROCKET_CLI_COLORS", true); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.set_env("ROCKET_CLI_COLORS", false); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.set_env("ROCKET_CLI_COLORS", "always"); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Always); - - jail.set_env("ROCKET_CLI_COLORS", "NEveR"); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.set_env("ROCKET_CLI_COLORS", "auTO"); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - Ok(()) - }) - } - - #[test] - fn test_profiles_merge() { - figment::Jail::expect_with(|jail| { - jail.create_file("Rocket.toml", r#" - [default.limits] - stream = "50kb" - - [global] - limits = { forms = "2kb" } - - [debug.limits] - file = "100kb" - "#)?; - - jail.set_env("ROCKET_PROFILE", "unknown"); - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - profile: Profile::const_new("unknown"), - limits: Limits::default() - .limit("stream", 50.kilobytes()) - .limit("forms", 2.kilobytes()), - ..Config::default() - }); - - jail.set_env("ROCKET_PROFILE", "debug"); - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - profile: Profile::const_new("debug"), - limits: Limits::default() - .limit("stream", 50.kilobytes()) - .limit("forms", 2.kilobytes()) - .limit("file", 100.kilobytes()), - ..Config::default() - }); - - Ok(()) - }); - } - - #[test] - fn test_env_vars_merge() { - use crate::config::{Ident, Shutdown}; - - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_KEEP_ALIVE", 9999); - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - keep_alive: 9999, - ..Config::default() - }); - - jail.set_env("ROCKET_SHUTDOWN", r#"{grace=7}"#); - let first_figment = Config::figment(); - jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=10}"#); - let prev_figment = Config::figment().join(&first_figment); - let config = Config::from(&prev_figment); - assert_eq!(config, Config { - keep_alive: 9999, - shutdown: Shutdown { grace: 7, mercy: 10, ..Default::default() }, - ..Config::default() - }); - - jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=20}"#); - let config = Config::from(Config::figment().join(&prev_figment)); - assert_eq!(config, Config { - keep_alive: 9999, - shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() }, - ..Config::default() - }); - - jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#); - let config = Config::from(Config::figment().join(&prev_figment)); - assert_eq!(config, Config { - keep_alive: 9999, - shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() }, - limits: Limits::default().limit("stream", 100.kibibytes()), - ..Config::default() - }); - - jail.set_env("ROCKET_IDENT", false); - let config = Config::from(Config::figment().join(&prev_figment)); - assert_eq!(config, Config { - keep_alive: 9999, - shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() }, - limits: Limits::default().limit("stream", 100.kibibytes()), - ident: Ident::none(), - ..Config::default() - }); - - Ok(()) - }); - } - - #[test] - fn test_precedence() { - figment::Jail::expect_with(|jail| { - jail.create_file("Rocket.toml", r#" - [global.limits] - forms = "1mib" - stream = "50kb" - file = "100kb" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - limits: Limits::default() - .limit("forms", 1.mebibytes()) - .limit("stream", 50.kilobytes()) - .limit("file", 100.kilobytes()), - ..Config::default() - }); - - jail.set_env("ROCKET_LIMITS", r#"{stream=3MiB,capture=2MiB}"#); - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - limits: Limits::default() - .limit("file", 100.kilobytes()) - .limit("forms", 1.mebibytes()) - .limit("stream", 3.mebibytes()) - .limit("capture", 2.mebibytes()), - ..Config::default() - }); - - jail.set_env("ROCKET_PROFILE", "foo"); - let val: Result = Config::figment().extract_inner("profile"); - assert!(val.is_err()); - - Ok(()) - }); - } - - #[test] - #[cfg(feature = "secrets")] - #[should_panic] - fn test_err_on_non_debug_and_no_secret_key() { - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PROFILE", "release"); - let rocket = crate::custom(Config::figment()); - let _result = crate::local::blocking::Client::untracked(rocket); - Ok(()) - }); - } - - #[test] - #[cfg(feature = "secrets")] - #[should_panic] - fn test_err_on_non_debug2_and_no_secret_key() { - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PROFILE", "boop"); - let rocket = crate::custom(Config::figment()); - let _result = crate::local::blocking::Client::tracked(rocket); - Ok(()) - }); - } - - #[test] - fn test_no_err_on_debug_and_no_secret_key() { - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PROFILE", "debug"); - let figment = Config::figment(); - assert!(crate::local::blocking::Client::untracked(crate::custom(&figment)).is_ok()); - crate::async_main(async { - let rocket = crate::custom(&figment); - assert!(crate::local::asynchronous::Client::tracked(rocket).await.is_ok()); - }); - - Ok(()) - }); - } - - #[test] - fn test_no_err_on_release_and_custom_secret_key() { - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PROFILE", "release"); - let key = "Bx4Gb+aSIfuoEyMHD4DvNs92+wmzfQK98qc6MiwyPY4="; - let figment = Config::figment().merge(("secret_key", key)); - - assert!(crate::local::blocking::Client::tracked(crate::custom(&figment)).is_ok()); - crate::async_main(async { - let rocket = crate::custom(&figment); - assert!(crate::local::asynchronous::Client::untracked(rocket).await.is_ok()); - }); - - Ok(()) - }); - } -} +#[doc(hidden)] +pub use config::{pretty_print_error, bail_with_config_error}; diff --git a/core/lib/src/config/tests.rs b/core/lib/src/config/tests.rs new file mode 100644 index 00000000..b5e3429f --- /dev/null +++ b/core/lib/src/config/tests.rs @@ -0,0 +1,394 @@ +use figment::{Figment, Profile}; +use pretty_assertions::assert_eq; + +use crate::log::LogLevel; +use crate::data::{Limits, ToByteUnit}; +use crate::config::{Config, CliColors}; + +#[test] +fn test_figment_is_default() { + figment::Jail::expect_with(|_| { + let mut default: Config = Config::figment().extract().unwrap(); + default.profile = Config::default().profile; + assert_eq!(default, Config::default()); + Ok(()) + }); +} + +#[test] +fn test_default_round_trip() { + figment::Jail::expect_with(|_| { + let original = Config::figment(); + let roundtrip = Figment::from(Config::from(&original)); + for figment in &[original, roundtrip] { + let config = Config::from(figment); + assert_eq!(config, Config::default()); + } + + Ok(()) + }); +} + +#[test] +fn test_profile_env() { + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_PROFILE", "debug"); + let figment = Config::figment(); + assert_eq!(figment.profile(), "debug"); + + jail.set_env("ROCKET_PROFILE", "release"); + let figment = Config::figment(); + assert_eq!(figment.profile(), "release"); + + jail.set_env("ROCKET_PROFILE", "random"); + let figment = Config::figment(); + assert_eq!(figment.profile(), "random"); + + Ok(()) + }); +} + +#[test] +fn test_toml_file() { + figment::Jail::expect_with(|jail| { + jail.create_file("Rocket.toml", r#" + [default] + ident = "Something Cool" + workers = 20 + keep_alive = 10 + log_level = "off" + cli_colors = 0 + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + workers: 20, + ident: ident!("Something Cool"), + keep_alive: 10, + log_level: LogLevel::Off, + cli_colors: CliColors::Never, + ..Config::default() + }); + + jail.create_file("Rocket.toml", r#" + [global] + ident = "Something Else Cool" + workers = 20 + keep_alive = 10 + log_level = "off" + cli_colors = 0 + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + workers: 20, + ident: ident!("Something Else Cool"), + keep_alive: 10, + log_level: LogLevel::Off, + cli_colors: CliColors::Never, + ..Config::default() + }); + + jail.set_env("ROCKET_CONFIG", "Other.toml"); + jail.create_file("Other.toml", r#" + [default] + workers = 20 + keep_alive = 10 + log_level = "off" + cli_colors = 0 + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + workers: 20, + keep_alive: 10, + log_level: LogLevel::Off, + cli_colors: CliColors::Never, + ..Config::default() + }); + + Ok(()) + }); +} + +#[test] +fn test_cli_colors() { + figment::Jail::expect_with(|jail| { + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = "never" + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = "auto" + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = "always" + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Always); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = true + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = false + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.create_file("Rocket.toml", r#"[default]"#)?; + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = 1 + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = 0 + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.set_env("ROCKET_CLI_COLORS", 1); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.set_env("ROCKET_CLI_COLORS", 0); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.set_env("ROCKET_CLI_COLORS", true); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.set_env("ROCKET_CLI_COLORS", false); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.set_env("ROCKET_CLI_COLORS", "always"); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Always); + + jail.set_env("ROCKET_CLI_COLORS", "NEveR"); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.set_env("ROCKET_CLI_COLORS", "auTO"); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + Ok(()) + }) +} + +#[test] +fn test_profiles_merge() { + figment::Jail::expect_with(|jail| { + jail.create_file("Rocket.toml", r#" + [default.limits] + stream = "50kb" + + [global] + limits = { forms = "2kb" } + + [debug.limits] + file = "100kb" + "#)?; + + jail.set_env("ROCKET_PROFILE", "unknown"); + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + profile: Profile::const_new("unknown"), + limits: Limits::default() + .limit("stream", 50.kilobytes()) + .limit("forms", 2.kilobytes()), + ..Config::default() + }); + + jail.set_env("ROCKET_PROFILE", "debug"); + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + profile: Profile::const_new("debug"), + limits: Limits::default() + .limit("stream", 50.kilobytes()) + .limit("forms", 2.kilobytes()) + .limit("file", 100.kilobytes()), + ..Config::default() + }); + + Ok(()) + }); +} + +#[test] +fn test_env_vars_merge() { + use crate::config::{Ident, ShutdownConfig}; + + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_KEEP_ALIVE", 9999); + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + keep_alive: 9999, + ..Config::default() + }); + + jail.set_env("ROCKET_SHUTDOWN", r#"{grace=7}"#); + let first_figment = Config::figment(); + jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=10}"#); + let prev_figment = Config::figment().join(&first_figment); + let config = Config::from(&prev_figment); + assert_eq!(config, Config { + keep_alive: 9999, + shutdown: ShutdownConfig { grace: 7, mercy: 10, ..Default::default() }, + ..Config::default() + }); + + jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=20}"#); + let config = Config::from(Config::figment().join(&prev_figment)); + assert_eq!(config, Config { + keep_alive: 9999, + shutdown: ShutdownConfig { grace: 7, mercy: 20, ..Default::default() }, + ..Config::default() + }); + + jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#); + let config = Config::from(Config::figment().join(&prev_figment)); + assert_eq!(config, Config { + keep_alive: 9999, + shutdown: ShutdownConfig { grace: 7, mercy: 20, ..Default::default() }, + limits: Limits::default().limit("stream", 100.kibibytes()), + ..Config::default() + }); + + jail.set_env("ROCKET_IDENT", false); + let config = Config::from(Config::figment().join(&prev_figment)); + assert_eq!(config, Config { + keep_alive: 9999, + shutdown: ShutdownConfig { grace: 7, mercy: 20, ..Default::default() }, + limits: Limits::default().limit("stream", 100.kibibytes()), + ident: Ident::none(), + ..Config::default() + }); + + Ok(()) + }); +} + +#[test] +fn test_precedence() { + figment::Jail::expect_with(|jail| { + jail.create_file("Rocket.toml", r#" + [global.limits] + forms = "1mib" + stream = "50kb" + file = "100kb" + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + limits: Limits::default() + .limit("forms", 1.mebibytes()) + .limit("stream", 50.kilobytes()) + .limit("file", 100.kilobytes()), + ..Config::default() + }); + + jail.set_env("ROCKET_LIMITS", r#"{stream=3MiB,capture=2MiB}"#); + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + limits: Limits::default() + .limit("file", 100.kilobytes()) + .limit("forms", 1.mebibytes()) + .limit("stream", 3.mebibytes()) + .limit("capture", 2.mebibytes()), + ..Config::default() + }); + + jail.set_env("ROCKET_PROFILE", "foo"); + let val: Result = Config::figment().extract_inner("profile"); + assert!(val.is_err()); + + Ok(()) + }); +} + +#[test] +#[cfg(feature = "secrets")] +#[should_panic] +fn test_err_on_non_debug_and_no_secret_key() { + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_PROFILE", "release"); + let rocket = crate::custom(Config::figment()); + let _result = crate::local::blocking::Client::untracked(rocket); + Ok(()) + }); +} + +#[test] +#[cfg(feature = "secrets")] +#[should_panic] +fn test_err_on_non_debug2_and_no_secret_key() { + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_PROFILE", "boop"); + let rocket = crate::custom(Config::figment()); + let _result = crate::local::blocking::Client::tracked(rocket); + Ok(()) + }); +} + +#[test] +fn test_no_err_on_debug_and_no_secret_key() { + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_PROFILE", "debug"); + let figment = Config::figment(); + assert!(crate::local::blocking::Client::untracked(crate::custom(&figment)).is_ok()); + crate::async_main(async { + let rocket = crate::custom(&figment); + assert!(crate::local::asynchronous::Client::tracked(rocket).await.is_ok()); + }); + + Ok(()) + }); +} + +#[test] +fn test_no_err_on_release_and_custom_secret_key() { + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_PROFILE", "release"); + let key = "Bx4Gb+aSIfuoEyMHD4DvNs92+wmzfQK98qc6MiwyPY4="; + let figment = Config::figment().merge(("secret_key", key)); + + assert!(crate::local::blocking::Client::tracked(crate::custom(&figment)).is_ok()); + crate::async_main(async { + let rocket = crate::custom(&figment); + assert!(crate::local::asynchronous::Client::untracked(rocket).await.is_ok()); + }); + + Ok(()) + }); +} diff --git a/core/lib/src/data/data_stream.rs b/core/lib/src/data/data_stream.rs index 77d03328..955c98f6 100644 --- a/core/lib/src/data/data_stream.rs +++ b/core/lib/src/data/data_stream.rs @@ -68,7 +68,9 @@ pub type RawReader<'r> = StreamReader, Bytes>; /// Raw underlying data stream. pub enum RawStream<'r> { Empty, - Body(&'r mut HyperBody), + Body(HyperBody), + #[cfg(feature = "http3-preview")] + H3Body(crate::listener::Cancellable), Multipart(multer::Field<'r>), } @@ -343,7 +345,9 @@ impl Stream for RawStream<'_> { .poll_frame(cx) .map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new())) .map_err(io::Error::other) - } + }, + #[cfg(feature = "http3-preview")] + RawStream::H3Body(stream) => Pin::new(stream).poll_next(cx), RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other), RawStream::Empty => Poll::Ready(None), } @@ -356,6 +360,8 @@ impl Stream for RawStream<'_> { let (lower, upper) = (hint.lower(), hint.upper()); (lower as usize, upper.map(|x| x as usize)) }, + #[cfg(feature = "http3-preview")] + RawStream::H3Body(_) => (0, Some(0)), RawStream::Multipart(mp) => mp.size_hint(), RawStream::Empty => (0, Some(0)), } @@ -367,17 +373,26 @@ impl std::fmt::Display for RawStream<'_> { match self { RawStream::Empty => f.write_str("empty stream"), RawStream::Body(_) => f.write_str("request body"), + #[cfg(feature = "http3-preview")] + RawStream::H3Body(_) => f.write_str("http3 quic stream"), RawStream::Multipart(_) => f.write_str("multipart form field"), } } } -impl<'r> From<&'r mut HyperBody> for RawStream<'r> { - fn from(value: &'r mut HyperBody) -> Self { +impl<'r> From for RawStream<'r> { + fn from(value: HyperBody) -> Self { Self::Body(value) } } +#[cfg(feature = "http3-preview")] +impl<'r> From> for RawStream<'r> { + fn from(value: crate::listener::Cancellable) -> Self { + Self::H3Body(value) + } +} + impl<'r> From> for RawStream<'r> { fn from(value: multer::Field<'r>) -> Self { Self::Multipart(value) diff --git a/core/lib/src/data/mod.rs b/core/lib/src/data/mod.rs index e3eebdd2..f7c879dc 100644 --- a/core/lib/src/data/mod.rs +++ b/core/lib/src/data/mod.rs @@ -18,3 +18,5 @@ pub use self::capped::{N, Capped}; pub use self::io_stream::{IoHandler, IoStream}; pub use ubyte::{ByteUnit, ToByteUnit}; pub use self::transform::{Transform, TransformBuf}; + +pub(crate) use self::data_stream::RawStream; diff --git a/core/lib/src/erased.rs b/core/lib/src/erased.rs index 7b62522c..966ff2dd 100644 --- a/core/lib/src/erased.rs +++ b/core/lib/src/erased.rs @@ -6,19 +6,18 @@ use std::task::{Poll, Context}; use futures::future::BoxFuture; use http::request::Parts; -use hyper::body::Incoming; use tokio::io::{AsyncRead, ReadBuf}; -use crate::data::{Data, IoHandler}; +use crate::data::{Data, IoHandler, RawStream}; use crate::{Request, Response, Rocket, Orbit}; // TODO: Magic with trait async fn to get rid of the box pin. // TODO: Write safety proofs. macro_rules! static_assert_covariance { - ($T:tt) => ( + ($($T:tt)*) => ( const _: () = { - fn _assert_covariance<'x: 'y, 'y>(x: &'y $T<'x>) -> &'y $T<'y> { x } + fn _assert_covariance<'x: 'y, 'y>(x: &'y $($T)*<'x>) -> &'y $($T)*<'y> { x } }; ) } @@ -40,7 +39,6 @@ pub struct ErasedResponse { // XXX: SAFETY: This (dependent) field must come first due to drop order! response: Response<'static>, _request: Arc, - _incoming: Box, } impl Drop for ErasedResponse { @@ -79,10 +77,9 @@ impl ErasedRequest { ErasedRequest { _rocket: rocket, _parts: parts, request, } } - pub async fn into_response( + pub async fn into_response( self, - incoming: Incoming, - data_builder: impl for<'r> FnOnce(&'r mut Incoming) -> Data<'r>, + raw_stream: D, preprocess: impl for<'r, 'x> FnOnce( &'r Rocket, &'r mut Request<'x>, @@ -94,14 +91,11 @@ impl ErasedRequest { &'r Request<'r>, Data<'r> ) -> BoxFuture<'r, Response<'r>>, - ) -> ErasedResponse { - let mut incoming = Box::new(incoming); - let mut data: Data<'_> = { - let incoming: &mut Incoming = &mut *incoming; - let incoming: &'static mut Incoming = unsafe { transmute(incoming) }; - data_builder(incoming) - }; - + ) -> ErasedResponse + where T: Send + Sync + 'static, + D: for<'r> Into> + { + let mut data: Data<'_> = Data::from(raw_stream); let mut parent = Arc::new(self); let token: T = { let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap(); @@ -122,7 +116,6 @@ impl ErasedRequest { ErasedResponse { _request: parent, - _incoming: incoming, response: response, } } diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 21753b1f..b072dd77 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -7,7 +7,8 @@ use std::error::Error as StdError; use yansi::Paint; use figment::Profile; -use crate::{Rocket, Orbit}; +use crate::listener::Endpoint; +use crate::{Ignite, Orbit, Rocket}; /// An error that occurs during launch. /// @@ -74,8 +75,8 @@ pub struct Error { #[derive(Debug)] #[non_exhaustive] pub enum ErrorKind { - /// Binding to the network interface failed. - Bind(Box), + /// Binding to the network interface at `.0` failed with error `.1`. + Bind(Option, Box), /// An I/O error occurred during launch. Io(io::Error), /// A valid [`Config`](crate::Config) could not be extracted from the @@ -89,6 +90,11 @@ pub enum ErrorKind { SentinelAborts(Vec), /// The configuration profile is not debug but no secret key is configured. InsecureSecretKey(Profile), + /// Liftoff failed. Contains the Rocket instance that failed to shutdown. + Liftoff( + Result, Arc>>, + Box + ), /// Shutdown failed. Contains the Rocket instance that failed to shutdown. Shutdown(Arc>), } @@ -171,8 +177,12 @@ impl Error { pub fn pretty_print(&self) -> &'static str { self.mark_handled(); match self.kind() { - ErrorKind::Bind(ref e) => { - error!("Binding to the network interface failed."); + ErrorKind::Bind(ref a, ref e) => { + match a { + Some(a) => error!("Binding to {} failed.", a.primary().underline()), + None => error!("Binding to network interface failed."), + } + info_!("{}", e); "aborting due to bind error" } @@ -225,6 +235,11 @@ impl Error { "aborting due to sentinel-triggered abort(s)" } + ErrorKind::Liftoff(_, error) => { + error!("Rocket liftoff failed due to panicking liftoff fairing(s)."); + error_!("{error}"); + "aborting due to failed liftoff" + } ErrorKind::Shutdown(_) => { error!("Rocket failed to shutdown gracefully."); "aborting due to failed shutdown" @@ -239,13 +254,14 @@ impl fmt::Display for ErrorKind { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ErrorKind::Bind(e) => write!(f, "binding failed: {e}"), + ErrorKind::Bind(_, e) => write!(f, "binding failed: {e}"), ErrorKind::Io(e) => write!(f, "I/O error: {e}"), ErrorKind::Collisions(_) => "collisions detected".fmt(f), ErrorKind::FailedFairings(_) => "launch fairing(s) failed".fmt(f), ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f), ErrorKind::Config(_) => "failed to extract configuration".fmt(f), ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f), + ErrorKind::Liftoff(_, _) => "liftoff failed".fmt(f), ErrorKind::Shutdown(_) => "shutdown failed".fmt(f), } } @@ -293,40 +309,45 @@ impl fmt::Display for Empty { impl StdError for Empty { } /// Log an error that occurs during request processing -pub(crate) fn log_server_error(error: &Box) { +#[track_caller] +pub(crate) fn log_server_error(error: &(dyn StdError + 'static)) { struct ServerError<'a>(&'a (dyn StdError + 'static)); impl fmt::Display for ServerError<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let error = &self.0; if let Some(e) = error.downcast_ref::() { - write!(f, "request processing failed: {e}")?; + write!(f, "request failed: {e}")?; } else if let Some(e) = error.downcast_ref::() { - write!(f, "connection I/O error: ")?; + write!(f, "connection error: ")?; match e.kind() { io::ErrorKind::NotConnected => write!(f, "remote disconnected")?, io::ErrorKind::UnexpectedEof => write!(f, "remote sent early eof")?, io::ErrorKind::ConnectionReset - | io::ErrorKind::ConnectionAborted - | io::ErrorKind::BrokenPipe => write!(f, "terminated by remote")?, + | io::ErrorKind::ConnectionAborted => write!(f, "terminated by remote")?, _ => write!(f, "{e}")?, } } else { write!(f, "http server error: {error}")?; } - if let Some(e) = error.source() { - write!(f, " ({})", ServerError(e))?; - } - Ok(()) } } + let mut error: &(dyn StdError + 'static) = &*error; if error.downcast_ref::().is_some() { - warn!("{}", ServerError(&**error)) + warn!("{}", ServerError(error)); + while let Some(source) = error.source() { + error = source; + warn_!("{}", ServerError(error)); + } } else { - error!("{}", ServerError(&**error)) + error!("{}", ServerError(error)); + while let Some(source) = error.source() { + error = source; + error_!("{}", ServerError(error)); + } } } diff --git a/core/lib/src/fairing/ad_hoc.rs b/core/lib/src/fairing/ad_hoc.rs index a83d6e58..a167c42f 100644 --- a/core/lib/src/fairing/ad_hoc.rs +++ b/core/lib/src/fairing/ad_hoc.rs @@ -187,7 +187,7 @@ impl AdHoc { /// Constructs an `AdHoc` shutdown fairing named `name`. The function `f` /// will be called by Rocket when [shutdown is triggered]. /// - /// [shutdown is triggered]: crate::config::Shutdown#triggers + /// [shutdown is triggered]: crate::config::ShutdownConfig#triggers /// /// # Example /// diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index fec2b33f..ad9aaca4 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -191,8 +191,8 @@ pub type Result, E = Rocket> = std::result::Result CookieJar<'a> { ops: Mutex::new(Vec::new()), state: CookieState { // This is updated dynamically when headers are received. - secure: rocket.endpoint().is_tls(), + secure: rocket.endpoints().all(|e| e.is_tls()), config: rocket.config(), } } diff --git a/core/lib/src/lib.rs b/core/lib/src/lib.rs index 0f6c3871..069f83d0 100644 --- a/core/lib/src/lib.rs +++ b/core/lib/src/lib.rs @@ -61,15 +61,17 @@ //! To avoid compiling unused dependencies, Rocket gates certain features. With //! the exception of `http2`, all are disabled by default: //! -//! | Feature | Description | -//! |-----------|---------------------------------------------------------| -//! | `secrets` | Support for authenticated, encrypted [private cookies]. | -//! | `tls` | Support for [TLS] encrypted connections. | -//! | `mtls` | Support for verified clients via [mutual TLS]. | -//! | `http2` | Support for HTTP/2 (enabled by default). | -//! | `json` | Support for [JSON (de)serialization]. | -//! | `msgpack` | Support for [MessagePack (de)serialization]. | -//! | `uuid` | Support for [UUID value parsing and (de)serialization]. | +//! | Feature | Description | +//! |-----------------|---------------------------------------------------------| +//! | `secrets` | Support for authenticated, encrypted [private cookies]. | +//! | `tls` | Support for [TLS] encrypted connections. | +//! | `mtls` | Support for verified clients via [mutual TLS]. | +//! | `http2` | Support for HTTP/2 (enabled by default). | +//! | `json` | Support for [JSON (de)serialization]. | +//! | `msgpack` | Support for [MessagePack (de)serialization]. | +//! | `uuid` | Support for [UUID value parsing and (de)serialization]. | +//! | `tokio-macros` | Enables the `macros` feature in the exported `tokio` | +//! | `http3-preview` | Experimental preview support for [HTTP/3]. | //! //! Disabled features can be selectively enabled in `Cargo.toml`: //! @@ -91,6 +93,7 @@ //! [private cookies]: https://rocket.rs/master/guide/requests/#private-cookies //! [TLS]: https://rocket.rs/master/guide/configuration/#tls //! [mutual TLS]: crate::mtls +//! [HTTP/3]: crate::listener::quic //! //! ## Configuration //! @@ -143,6 +146,7 @@ pub mod shield; pub mod fs; pub mod http; pub mod listener; +pub mod shutdown; #[cfg(feature = "tls")] #[cfg_attr(nightly, doc(cfg(feature = "tls")))] pub mod tls; @@ -151,7 +155,6 @@ pub mod tls; pub mod mtls; mod util; -mod shutdown; mod server; mod lifecycle; mod state; diff --git a/core/lib/src/listener/bindable.rs b/core/lib/src/listener/bindable.rs index 67021382..62eb9e7b 100644 --- a/core/lib/src/listener/bindable.rs +++ b/core/lib/src/listener/bindable.rs @@ -1,6 +1,7 @@ +use std::io; use futures::TryFutureExt; -use crate::listener::Listener; +use crate::listener::{Listener, Endpoint}; pub trait Bindable: Sized { type Listener: Listener + 'static; @@ -8,6 +9,9 @@ pub trait Bindable: Sized { type Error: std::error::Error + Send + 'static; async fn bind(self) -> Result; + + /// The endpoint that `self` binds on. + fn candidate_endpoint(&self) -> io::Result; } impl Bindable for L { @@ -18,6 +22,10 @@ impl Bindable for L { async fn bind(self) -> Result { Ok(self) } + + fn candidate_endpoint(&self) -> io::Result { + L::endpoint(self) + } } impl Bindable for either::Either { @@ -37,4 +45,8 @@ impl Bindable for either::Either { .await, } } + + fn candidate_endpoint(&self) -> io::Result { + either::for_both!(self, a => a.candidate_endpoint()) + } } diff --git a/core/lib/src/listener/bounced.rs b/core/lib/src/listener/bounced.rs index c8e4203b..ad264192 100644 --- a/core/lib/src/listener/bounced.rs +++ b/core/lib/src/listener/bounced.rs @@ -52,7 +52,7 @@ impl Listener for Bounced { self.listener.connect(accept).await } - fn socket_addr(&self) -> io::Result { - self.listener.socket_addr() + fn endpoint(&self) -> io::Result { + self.listener.endpoint() } } diff --git a/core/lib/src/listener/cancellable.rs b/core/lib/src/listener/cancellable.rs index fbabfb2c..7ec9b81b 100644 --- a/core/lib/src/listener/cancellable.rs +++ b/core/lib/src/listener/cancellable.rs @@ -1,178 +1,79 @@ use std::io; -use std::time::Duration; use std::task::{Poll, Context}; use std::pin::Pin; -use tokio::time::{sleep, Sleep}; +use futures::Stream; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use futures::{StreamExt, future::{select, Either, Fuse, Future, FutureExt}}; +use futures::future::FutureExt; use pin_project_lite::pin_project; -use crate::{config, Shutdown}; -use crate::listener::{Listener, Connection, Certificates, Bounced, Endpoint}; - -// Rocket wraps all connections in a `CancellableIo` struct, an internal -// structure that gracefully closes I/O when it receives a signal. That signal -// is the `shutdown` future. When the future resolves, `CancellableIo` begins to -// terminate in grace, mercy, and finally force close phases. Since all -// connections are wrapped in `CancellableIo`, this eventually ends all I/O. -// -// At that point, unless a user spawned an infinite, stand-alone task that isn't -// monitoring `Shutdown`, all tasks should resolve. This means that all -// instances of the shared `Arc` are dropped and we can return the owned -// instance of `Rocket`. -// -// Unfortunately, the Hyper `server` future resolves as soon as it has finished -// processing requests without respect for ongoing responses. That is, `server` -// resolves even when there are running tasks that are generating a response. -// So, `server` resolving implies little to nothing about the state of -// connections. As a result, we depend on the timing of grace + mercy + some -// buffer to determine when all connections should be closed, thus all tasks -// should be complete, thus all references to `Arc` should be dropped -// and we can get a unique reference. -pin_project! { - pub struct CancellableListener { - pub trigger: F, - #[pin] - pub listener: L, - pub grace: Duration, - pub mercy: Duration, - } -} +use crate::shutdown::Stages; pin_project! { /// I/O that can be cancelled when a future `F` resolves. #[must_use = "futures do nothing unless polled"] - pub struct CancellableIo { + pub struct Cancellable { #[pin] io: Option, - #[pin] - trigger: Fuse, + stages: Stages, state: State, - grace: Duration, - mercy: Duration, } } +#[derive(Debug)] enum State { - /// I/O has not been cancelled. Proceed as normal. + /// I/O has not been cancelled. Proceed as normal until `Shutdown`. Active, - /// I/O has been cancelled. See if we can finish before the timer expires. - Grace(Pin>), - /// Grace period elapsed. Shutdown the connection, waiting for the timer - /// until we force close. - Mercy(Pin>), + /// I/O has been cancelled. Try to finish before `Shutdown`. + Grace, + /// Grace has elapsed. Shutdown connections. After `Shutdown`, force close. + Mercy, } pub trait CancellableExt: Sized { - fn cancellable( - self, - trigger: Shutdown, - config: &config::Shutdown - ) -> CancellableListener { - if let Some(mut stream) = config.signal_stream() { - let trigger = trigger.clone(); - tokio::spawn(async move { - while let Some(sig) = stream.next().await { - if trigger.0.tripped() { - warn!("Received {}. Shutdown already in progress.", sig); - } else { - warn!("Received {}. Requesting shutdown.", sig); - } - - trigger.0.trip(); - } - }); - }; - - CancellableListener { - trigger, - listener: self, - grace: config.grace(), - mercy: config.mercy(), + fn cancellable(self, stages: Stages) -> Cancellable { + Cancellable { + io: Some(self), + state: State::Active, + stages, } } } -impl CancellableExt for L { } +impl CancellableExt for T { } fn time_out() -> io::Error { - io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out") + io::Error::new(io::ErrorKind::TimedOut, "shutdown grace period elapsed") } fn gone() -> io::Error { - io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated") + io::Error::new(io::ErrorKind::BrokenPipe, "I/O driver terminated") } -impl CancellableListener> - where L: Listener + Sync, - F: Future + Unpin + Clone + Send + Sync + 'static -{ - pub async fn accept_next(&self) -> Option<::Accept> { - let next = std::pin::pin!(self.listener.accept_next()); - match select(next, self.trigger.clone()).await { - Either::Left((next, _)) => Some(next), - Either::Right(_) => None, - } - } -} - -impl CancellableListener - where L: Listener + Sync, - F: Future + Clone + Send + Sync + 'static -{ - fn io(&self, conn: C) -> CancellableIo { - CancellableIo { - io: Some(conn), - trigger: self.trigger.clone().fuse(), - state: State::Active, - grace: self.grace, - mercy: self.mercy, - } - } -} - -impl Listener for CancellableListener - where L: Listener + Sync, - F: Future + Clone + Send + Sync + Unpin + 'static -{ - type Accept = L::Accept; - - type Connection = CancellableIo; - - async fn accept(&self) -> io::Result { - let accept = std::pin::pin!(self.listener.accept()); - match select(accept, self.trigger.clone()).await { - Either::Left((result, _)) => result, - Either::Right(_) => Err(gone()), - } - } - - async fn connect(&self, accept: Self::Accept) -> io::Result { - let conn = std::pin::pin!(self.listener.connect(accept)); - match select(conn, self.trigger.clone()).await { - Either::Left((conn, _)) => Ok(self.io(conn?)), - Either::Right(_) => Err(gone()), - } - } - - fn socket_addr(&self) -> io::Result { - self.listener.socket_addr() - } -} - -impl CancellableIo { - fn inner(&self) -> Option<&I> { +impl Cancellable { + pub fn inner(&self) -> Option<&I> { self.io.as_ref() } +} +pub trait AsyncCancel { + fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; +} + +impl AsyncCancel for T { + fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ::poll_shutdown(self, cx) + } +} + +impl Cancellable { /// Run `do_io` while connection processing should continue. - fn poll_trigger_then( + pub fn poll_with( mut self: Pin<&mut Self>, cx: &mut Context<'_>, do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll>, ) -> Poll> { - let mut me = self.as_mut().project(); + let me = self.as_mut().project(); let io = match me.io.as_pin_mut() { Some(io) => io, None => return Poll::Ready(Err(gone())), @@ -181,29 +82,29 @@ impl CancellableIo { loop { match me.state { State::Active => { - if me.trigger.as_mut().poll(cx).is_ready() { - *me.state = State::Grace(Box::pin(sleep(*me.grace))); + if me.stages.start.poll_unpin(cx).is_ready() { + *me.state = State::Grace; } else { return do_io(io, cx); } } - State::Grace(timer) => { - if timer.as_mut().poll(cx).is_ready() { - *me.state = State::Mercy(Box::pin(sleep(*me.mercy))); + State::Grace => { + if me.stages.grace.poll_unpin(cx).is_ready() { + *me.state = State::Mercy; } else { return do_io(io, cx); } } - State::Mercy(timer) => { - if timer.as_mut().poll(cx).is_ready() { + State::Mercy => { + if me.stages.mercy.poll_unpin(cx).is_ready() { self.project().io.set(None); return Poll::Ready(Err(time_out())); } else { - let result = futures::ready!(io.poll_shutdown(cx)); + let result = futures::ready!(io.poll_cancel(cx)); self.project().io.set(None); return match result { + Ok(()) => Poll::Ready(Err(gone())), Err(e) => Poll::Ready(Err(e)), - Ok(()) => Poll::Ready(Err(gone())) }; } }, @@ -212,45 +113,45 @@ impl CancellableIo { } } -impl AsyncRead for CancellableIo { +impl AsyncRead for Cancellable { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf)) + self.poll_with(cx, |io, cx| io.poll_read(cx, buf)) } } -impl AsyncWrite for CancellableIo { +impl AsyncWrite for Cancellable { fn poll_write( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf)) + self.poll_with(cx, |io, cx| io.poll_write(cx, buf)) } fn poll_flush( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx)) + self.poll_with(cx, |io, cx| io.poll_flush(cx)) } fn poll_shutdown( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx)) + self.poll_with(cx, |io, cx| io.poll_shutdown(cx)) } fn poll_write_vectored( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs)) + self.poll_with(cx, |io, cx| io.poll_write_vectored(cx, bufs)) } fn is_write_vectored(&self) -> bool { @@ -258,16 +159,16 @@ impl AsyncWrite for CancellableIo { } } -impl Connection for CancellableIo - where F: Unpin + Send + 'static -{ - fn peer_address(&self) -> io::Result { - self.inner() - .ok_or_else(|| gone()) - .and_then(|io| io.peer_address()) - } +impl> + AsyncCancel> Stream for Cancellable { + type Item = I::Item; - fn peer_certificates(&self) -> Option> { - self.inner().and_then(|io| io.peer_certificates()) + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use futures::ready; + + match ready!(self.poll_with(cx, |io, cx| io.poll_next(cx).map(Ok))) { + Ok(Some(v)) => Poll::Ready(Some(v)), + Ok(None) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(e))), + } } } diff --git a/core/lib/src/listener/connection.rs b/core/lib/src/listener/connection.rs index 68541109..49d17789 100644 --- a/core/lib/src/listener/connection.rs +++ b/core/lib/src/listener/connection.rs @@ -2,7 +2,6 @@ use std::io; use std::borrow::Cow; use tokio_util::either::Either; -use tokio::io::{AsyncRead, AsyncWrite}; use super::Endpoint; @@ -10,8 +9,8 @@ use super::Endpoint; #[derive(Clone)] pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>); -pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin { - fn peer_address(&self) -> io::Result; +pub trait Connection: Send + Unpin { + fn endpoint(&self) -> io::Result; /// DER-encoded X.509 certificate chain presented by the client, if any. /// @@ -21,21 +20,21 @@ pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin { /// /// Defaults to an empty vector to indicate that no certificates were /// presented. - fn peer_certificates(&self) -> Option> { None } + fn certificates(&self) -> Option> { None } } impl Connection for Either { - fn peer_address(&self) -> io::Result { + fn endpoint(&self) -> io::Result { match self { - Either::Left(c) => c.peer_address(), - Either::Right(c) => c.peer_address(), + Either::Left(c) => c.endpoint(), + Either::Right(c) => c.endpoint(), } } - fn peer_certificates(&self) -> Option> { + fn certificates(&self) -> Option> { match self { - Either::Left(c) => c.peer_certificates(), - Either::Right(c) => c.peer_certificates(), + Either::Left(c) => c.certificates(), + Either::Right(c) => c.certificates(), } } } diff --git a/core/lib/src/listener/default.rs b/core/lib/src/listener/default.rs index 32f4a650..5732abe0 100644 --- a/core/lib/src/listener/default.rs +++ b/core/lib/src/listener/default.rs @@ -32,15 +32,15 @@ impl DefaultListener { Ok(BaseBindable::Right(uds)) }, #[cfg(not(unix))] - Endpoint::Unix(_) => { + e@Endpoint::Unix(_) => { let msg = "Unix domain sockets unavailable on non-unix platforms."; let boxed = Box::::from(msg); - Err(Error::new(ErrorKind::Bind(boxed))) + Err(Error::new(ErrorKind::Bind(Some(e.clone()), boxed))) }, other => { let msg = format!("unsupported default listener address: {other}"); let boxed = Box::::from(msg); - Err(Error::new(ErrorKind::Bind(boxed))) + Err(Error::new(ErrorKind::Bind(Some(other.clone()), boxed))) } } } diff --git a/core/lib/src/listener/endpoint.rs b/core/lib/src/listener/endpoint.rs index 26640d1d..e9c45aaa 100644 --- a/core/lib/src/listener/endpoint.rs +++ b/core/lib/src/listener/endpoint.rs @@ -1,7 +1,7 @@ use std::fmt; -use std::path::{Path, PathBuf}; use std::any::Any; -use std::net::{SocketAddr as TcpAddr, Ipv4Addr, AddrParseError}; +use std::net::{self, AddrParseError, IpAddr, Ipv4Addr}; +use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::Arc; @@ -9,23 +9,23 @@ use serde::de; use crate::http::uncased::AsUncased; +#[cfg(feature = "tls")] type TlsInfo = Option>; +#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>; + pub trait EndpointAddr: fmt::Display + fmt::Debug + Sync + Send + Any { } impl EndpointAddr for T {} -#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>; -#[cfg(feature = "tls")] type TlsInfo = Option; - /// # Conversions /// /// * [`&str`] - parse with [`FromStr`] -/// * [`tokio::net::unix::SocketAddr`] - must be path: [`ListenerAddr::Unix`] -/// * [`std::net::SocketAddr`] - infallibly as [ListenerAddr::Tcp] -/// * [`PathBuf`] - infallibly as [`ListenerAddr::Unix`] -// TODO: Rename to something better. `Endpoint`? -#[derive(Debug)] +/// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`] +/// * [`PathBuf`] - infallibly as [`Endpoint::Unix`] +#[derive(Debug, Clone)] +#[non_exhaustive] pub enum Endpoint { - Tcp(TcpAddr), + Tcp(net::SocketAddr), + Quic(net::SocketAddr), Unix(PathBuf), Tls(Arc, TlsInfo), Custom(Arc), @@ -36,9 +36,45 @@ impl Endpoint { Endpoint::Custom(Arc::new(value)) } - pub fn tcp(&self) -> Option { + pub fn tcp(&self) -> Option { match self { Endpoint::Tcp(addr) => Some(*addr), + Endpoint::Tls(addr, _) => addr.tcp(), + _ => None, + } + } + + pub fn quic(&self) -> Option { + match self { + Endpoint::Quic(addr) => Some(*addr), + Endpoint::Tls(addr, _) => addr.tcp(), + _ => None, + } + } + + pub fn socket_addr(&self) -> Option { + match self { + Endpoint::Quic(addr) => Some(*addr), + Endpoint::Tcp(addr) => Some(*addr), + Endpoint::Tls(inner, _) => inner.socket_addr(), + _ => None, + } + } + + pub fn ip(&self) -> Option { + match self { + Endpoint::Quic(addr) => Some(addr.ip()), + Endpoint::Tcp(addr) => Some(addr.ip()), + Endpoint::Tls(inner, _) => inner.ip(), + _ => None, + } + } + + pub fn port(&self) -> Option { + match self { + Endpoint::Quic(addr) => Some(addr.port()), + Endpoint::Tcp(addr) => Some(addr.port()), + Endpoint::Tls(inner, _) => inner.port(), _ => None, } } @@ -46,6 +82,7 @@ impl Endpoint { pub fn unix(&self) -> Option<&Path> { match self { Endpoint::Unix(addr) => Some(addr), + Endpoint::Tls(addr, _) => addr.unix(), _ => None, } } @@ -76,6 +113,7 @@ impl Endpoint { pub fn downcast(&self) -> Option<&T> { match self { Endpoint::Tcp(addr) => (&*addr as &dyn Any).downcast_ref(), + Endpoint::Quic(addr) => (&*addr as &dyn Any).downcast_ref(), Endpoint::Unix(addr) => (&*addr as &dyn Any).downcast_ref(), Endpoint::Custom(addr) => (&*addr as &dyn Any).downcast_ref(), Endpoint::Tls(inner, ..) => inner.downcast(), @@ -86,6 +124,10 @@ impl Endpoint { self.tcp().is_some() } + pub fn is_quic(&self) -> bool { + self.quic().is_some() + } + pub fn is_unix(&self) -> bool { self.unix().is_some() } @@ -95,12 +137,12 @@ impl Endpoint { } #[cfg(feature = "tls")] - pub fn with_tls(self, config: crate::tls::TlsConfig) -> Endpoint { + pub fn with_tls(self, tls: &crate::tls::TlsConfig) -> Endpoint { if self.is_tls() { return self; } - Self::Tls(Arc::new(self), Some(config)) + Self::Tls(Arc::new(self), Some(Box::new(tls.clone()))) } pub fn assume_tls(self) -> Endpoint { @@ -117,39 +159,27 @@ impl fmt::Display for Endpoint { use Endpoint::*; match self { - Tcp(addr) => write!(f, "http://{addr}"), + Tcp(addr) | Quic(addr) => write!(f, "http://{addr}"), Unix(addr) => write!(f, "unix:{}", addr.display()), Custom(inner) => inner.fmt(f), - Tls(inner, c) => match (&**inner, c.as_ref()) { + Tls(inner, _c) => { + match (inner.tcp(), inner.quic()) { + (Some(addr), _) => write!(f, "https://{addr} (TCP")?, + (_, Some(addr)) => write!(f, "https://{addr} (QUIC")?, + (None, None) => write!(f, "{inner} (TLS")?, + } + #[cfg(feature = "mtls")] - (Tcp(i), Some(c)) if c.mutual().is_some() => write!(f, "https://{i} (TLS + MTLS)"), - (Tcp(i), _) => write!(f, "https://{i} (TLS)"), - #[cfg(feature = "mtls")] - (i, Some(c)) if c.mutual().is_some() => write!(f, "{i} (TLS + MTLS)"), - (inner, _) => write!(f, "{inner} (TLS)"), - }, + if _c.as_ref().and_then(|c| c.mutual()).is_some() { + write!(f, " + mTLS")?; + } + + write!(f, ")") + } } } } -impl From for Endpoint { - fn from(value: std::net::SocketAddr) -> Self { - Self::Tcp(value) - } -} - -impl From for Endpoint { - fn from(value: std::net::SocketAddrV4) -> Self { - Self::Tcp(value.into()) - } -} - -impl From for Endpoint { - fn from(value: std::net::SocketAddrV6) -> Self { - Self::Tcp(value.into()) - } -} - impl From for Endpoint { fn from(value: PathBuf) -> Self { Self::Unix(value) @@ -177,36 +207,38 @@ impl TryFrom<&str> for Endpoint { impl Default for Endpoint { fn default() -> Self { - Endpoint::Tcp(TcpAddr::new(Ipv4Addr::LOCALHOST.into(), 8000)) + Endpoint::Tcp(net::SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000)) } } -/// Parses an address into a `ListenerAddr`. +/// Parses an address into a `Endpoint`. /// /// The syntax is: /// /// ```text -/// listener_addr = 'tcp' ':' tcp_addr | 'unix' ':' unix_addr | tcp_addr -/// tcp_addr := IP_ADDR | SOCKET_ADDR -/// unix_addr := PATH +/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket +/// socket := IP_ADDR | SOCKET_ADDR +/// path := PATH /// /// IP_ADDR := `std::net::IpAddr` string as defined by Rust /// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust /// PATH := `PathBuf` (any UTF-8) string as defined by Rust /// ``` /// -/// If `IP_ADDR` is specified, the port defaults to `8000`. +/// If `IP_ADDR` is specified in socket, port defaults to `8000`. impl FromStr for Endpoint { type Err = AddrParseError; fn from_str(string: &str) -> Result { - fn parse_tcp(string: &str, def_port: u16) -> Result { - string.parse().or_else(|_| string.parse().map(|ip| TcpAddr::new(ip, def_port))) + fn parse_tcp(str: &str, def_port: u16) -> Result { + str.parse().or_else(|_| str.parse().map(|ip| net::SocketAddr::new(ip, def_port))) } if let Some((proto, string)) = string.split_once(':') { if proto.trim().as_uncased() == "tcp" { return parse_tcp(string.trim(), 8000).map(Self::Tcp); + } else if proto.trim().as_uncased() == "quic" { + return parse_tcp(string.trim(), 8000).map(Self::Quic); } else if proto.trim().as_uncased() == "unix" { return Ok(Self::Unix(PathBuf::from(string.trim()))); } @@ -242,6 +274,7 @@ impl PartialEq for Endpoint { fn eq(&self, other: &Self) -> bool { match (self, other) { (Self::Tcp(l0), Self::Tcp(r0)) => l0 == r0, + (Self::Quic(l0), Self::Quic(r0)) => l0 == r0, (Self::Unix(l0), Self::Unix(r0)) => l0 == r0, (Self::Tls(l0, _), Self::Tls(r0, _)) => l0 == r0, (Self::Custom(l0), Self::Custom(r0)) => l0.to_string() == r0.to_string(), @@ -250,24 +283,6 @@ impl PartialEq for Endpoint { } } -impl PartialEq for Endpoint { - fn eq(&self, other: &std::net::SocketAddr) -> bool { - self.tcp() == Some(*other) - } -} - -impl PartialEq for Endpoint { - fn eq(&self, other: &std::net::SocketAddrV4) -> bool { - self.tcp() == Some((*other).into()) - } -} - -impl PartialEq for Endpoint { - fn eq(&self, other: &std::net::SocketAddrV6) -> bool { - self.tcp() == Some((*other).into()) - } -} - impl PartialEq for Endpoint { fn eq(&self, other: &PathBuf) -> bool { self.unix() == Some(other.as_path()) diff --git a/core/lib/src/listener/listener.rs b/core/lib/src/listener/listener.rs index 8bdbc08c..a272b699 100644 --- a/core/lib/src/listener/listener.rs +++ b/core/lib/src/listener/listener.rs @@ -10,12 +10,13 @@ pub trait Listener: Send + Sync { type Connection: Connection; + #[crate::async_bound(Send)] async fn accept(&self) -> io::Result; #[crate::async_bound(Send)] async fn connect(&self, accept: Self::Accept) -> io::Result; - fn socket_addr(&self) -> io::Result; + fn endpoint(&self) -> io::Result; } impl Listener for &L { @@ -31,8 +32,8 @@ impl Listener for &L { ::connect(self, accept).await } - fn socket_addr(&self) -> io::Result { - ::socket_addr(self) + fn endpoint(&self) -> io::Result { + ::endpoint(self) } } @@ -56,10 +57,10 @@ impl Listener for Either { } } - fn socket_addr(&self) -> io::Result { + fn endpoint(&self) -> io::Result { match self { - Either::Left(l) => l.socket_addr(), - Either::Right(l) => l.socket_addr(), + Either::Left(l) => l.endpoint(), + Either::Right(l) => l.endpoint(), } } } diff --git a/core/lib/src/listener/mod.rs b/core/lib/src/listener/mod.rs index 244c36c6..4e0ea0c8 100644 --- a/core/lib/src/listener/mod.rs +++ b/core/lib/src/listener/mod.rs @@ -13,6 +13,8 @@ pub mod unix; #[cfg_attr(nightly, doc(cfg(feature = "tls")))] pub mod tls; pub mod tcp; +#[cfg(feature = "http3-preview")] +pub mod quic; pub use endpoint::*; pub use listener::*; diff --git a/core/lib/src/listener/quic.rs b/core/lib/src/listener/quic.rs new file mode 100644 index 00000000..e43cabc9 --- /dev/null +++ b/core/lib/src/listener/quic.rs @@ -0,0 +1,234 @@ +//! Experimental support for Quic and HTTP/3. +//! +//! To enable Rocket's experimental support for HTTP/3 and Quic, enable the +//! `http3-preview` feature and provide a valid TLS configuration: +//! +//! ```toml +//! // Add the following to your Cargo.toml: +//! [dependencies] +//! rocket = { version = "0.6.0-dev", features = ["http3-preview"] } +//! +//! // In your Rocket.toml or other equivalent config source: +//! [default.tls] +//! certs = "private/rsa_sha256_cert.pem" +//! key = "private/rsa_sha256_key.pem" +//! ``` +//! +//! The launch message confirms that Rocket is serving traffic over Quic in +//! addition to TCP: +//! +//! ```sh +//! > 🚀 Rocket has launched on https://127.0.0.1:8000 (QUIC + mTLS) +//! ``` +//! +//! mTLS is not yet supported via this implementation. + +use std::io; +use std::fmt; +use std::net::SocketAddr; +use std::pin::pin; + +use s2n_quic as quic; +use s2n_quic_h3 as quic_h3; +use quic_h3::h3 as h3; + +use bytes::Bytes; +use futures::Stream; +use tokio::sync::Mutex; +use tokio_stream::StreamExt; + +use crate::tls::TlsConfig; +use crate::listener::{Listener, Connection, Endpoint}; + +type H3Conn = h3::server::Connection; + +pub struct QuicListener { + endpoint: SocketAddr, + listener: Mutex, + tls: TlsConfig, +} + +pub struct H3Stream(H3Conn); + +pub struct H3Connection { + pub handle: quic::connection::Handle, + pub parts: http::request::Parts, + pub tx: QuicTx, + pub rx: QuicRx, +} + +pub struct QuicRx(h3::server::RequestStream); + +pub struct QuicTx(h3::server::RequestStream, Bytes>); + +impl QuicListener { + pub async fn bind(address: SocketAddr, tls: TlsConfig) -> Result { + use quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES, Server as H3TlsServer}; + + // FIXME: Remove this as soon as `s2n_quic` is on rustls >= 0.22. + let cert_chain = crate::tls::util::load_cert_chain(&mut tls.certs_reader().unwrap()) + .unwrap() + .into_iter() + .map(|v| v.to_vec()) + .map(rustls::Certificate) + .collect::>(); + + let key = crate::tls::util::load_key(&mut tls.key_reader().unwrap()) + .unwrap() + .secret_der() + .to_vec(); + + let mut h3tls = rustls::server::ServerConfig::builder() + .with_cipher_suites(DEFAULT_CIPHERSUITES) + .with_safe_default_kx_groups() + .with_safe_default_protocol_versions() + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))? + .with_client_cert_verifier(rustls::server::NoClientAuth::boxed()) + .with_single_cert(cert_chain, rustls::PrivateKey(key)) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?; + + h3tls.alpn_protocols = vec![b"h3".to_vec()]; + h3tls.ignore_client_order = tls.prefer_server_cipher_order; + h3tls.session_storage = rustls::server::ServerSessionMemoryCache::new(1024); + h3tls.ticketer = rustls::Ticketer::new() + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS ticketer: {}", e)))?; + + let listener = quic::Server::builder() + .with_tls(H3TlsServer::new(h3tls)) + .unwrap_or_else(|e| match e { }) + .with_io(address)? + .start() + .map_err(io::Error::other)?; + + Ok(QuicListener { + tls, + endpoint: listener.local_addr()?, + listener: Mutex::new(listener), + }) + } +} + +impl Listener for QuicListener { + type Accept = quic::Connection; + + type Connection = H3Stream; + + async fn accept(&self) -> io::Result { + self.listener + .lock().await + .accept().await + .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "closed")) + } + + async fn connect(&self, accept: Self::Accept) -> io::Result { + let quic_conn = quic_h3::Connection::new(accept); + let conn = H3Conn::new(quic_conn).await.map_err(io::Error::other)?; + Ok(H3Stream(conn)) + } + + fn endpoint(&self) -> io::Result { + Ok(Endpoint::Quic(self.endpoint).with_tls(&self.tls)) + } +} + +impl H3Stream { + pub async fn accept(&mut self) -> io::Result> { + let handle = self.0.inner.conn.handle().clone(); + let ((parts, _), (tx, rx)) = match self.0.accept().await { + Ok(Some((req, stream))) => (req.into_parts(), stream.split()), + Ok(None) => return Ok(None), + Err(e) => { + if matches!(e.try_get_code().map(|c| c.value()), Some(0 | 0x100)) { + return Ok(None) + } + + return Err(io::Error::other(e)); + } + }; + + Ok(Some(H3Connection { handle, parts, tx: QuicTx(tx), rx: QuicRx(rx) })) + } +} + +impl QuicTx { + pub async fn send_response(&mut self, response: http::Response) -> io::Result<()> + where S: Stream> + { + let (parts, body) = response.into_parts(); + let response = http::Response::from_parts(parts, ()); + self.0.send_response(response).await.map_err(io::Error::other)?; + + let mut body = pin!(body); + while let Some(bytes) = body.next().await { + let bytes = bytes.map_err(io::Error::other)?; + self.0.send_data(bytes).await.map_err(io::Error::other)?; + } + + self.0.finish().await.map_err(io::Error::other) + } + + pub fn cancel(&mut self) { + self.0.stop_stream(h3::error::Code::H3_NO_ERROR); + } +} + +// FIXME: Expose certificates when possible. +impl Connection for H3Stream { + fn endpoint(&self) -> io::Result { + let addr = self.0.inner.conn.handle().remote_addr()?; + Ok(Endpoint::Quic(addr).assume_tls()) + } +} + +// FIXME: Expose certificates when possible. +impl Connection for H3Connection { + fn endpoint(&self) -> io::Result { + let addr = self.handle.remote_addr()?; + Ok(Endpoint::Quic(addr).assume_tls()) + } +} + +mod async_traits { + use std::io; + use std::pin::Pin; + use std::task::{ready, Context, Poll}; + + use super::{Bytes, QuicRx}; + use crate::listener::AsyncCancel; + + use futures::Stream; + use s2n_quic_h3::h3; + + impl Stream for QuicRx { + type Item = io::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use bytes::Buf; + + match ready!(self.0.poll_recv_data(cx)) { + Ok(Some(mut buf)) => Poll::Ready(Some(Ok(buf.copy_to_bytes(buf.remaining())))), + Ok(None) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(io::Error::other(e)))), + } + } + } + + impl AsyncCancel for QuicRx { + fn poll_cancel(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + self.0.stop_sending(h3::error::Code::H3_NO_ERROR); + Poll::Ready(Ok(())) + } + } +} + +impl fmt::Debug for H3Stream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("H3Stream").finish() + } +} + +impl fmt::Debug for H3Connection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("H3Connection").finish() + } +} diff --git a/core/lib/src/listener/tcp.rs b/core/lib/src/listener/tcp.rs index c2e3fd9f..32b7e165 100644 --- a/core/lib/src/listener/tcp.rs +++ b/core/lib/src/listener/tcp.rs @@ -13,6 +13,10 @@ impl Bindable for std::net::SocketAddr { async fn bind(self) -> Result { TcpListener::bind(self).await } + + fn candidate_endpoint(&self) -> io::Result { + Ok(Endpoint::Tcp(*self)) + } } impl Listener for TcpListener { @@ -31,13 +35,13 @@ impl Listener for TcpListener { Ok(conn) } - fn socket_addr(&self) -> io::Result { + fn endpoint(&self) -> io::Result { self.local_addr().map(Endpoint::Tcp) } } impl Connection for TcpStream { - fn peer_address(&self) -> io::Result { + fn endpoint(&self) -> io::Result { self.peer_addr().map(Endpoint::Tcp) } } diff --git a/core/lib/src/listener/tls.rs b/core/lib/src/listener/tls.rs index ce2b53ff..1eea9fc1 100644 --- a/core/lib/src/listener/tls.rs +++ b/core/lib/src/listener/tls.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use serde::Deserialize; use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; use crate::tls::{TlsConfig, Error}; @@ -27,7 +28,7 @@ pub struct TlsBindable { } impl TlsConfig { - pub(crate) fn acceptor(&self) -> Result { + pub(crate) fn server_config(&self) -> Result { let provider = rustls::crypto::CryptoProvider { cipher_suites: self.ciphers().map(|c| c.into()).collect(), ..rustls::crypto::ring::default_provider() @@ -64,52 +65,60 @@ impl TlsConfig { tls_config.alpn_protocols.insert(0, b"h2".to_vec()); } - Ok(TlsAcceptor::from(Arc::new(tls_config))) + Ok(tls_config) } } -impl Bindable for TlsBindable { +impl Bindable for TlsBindable + where I::Listener: Listener::Connection>, + ::Connection: AsyncRead + AsyncWrite +{ type Listener = TlsListener; type Error = Error; async fn bind(self) -> Result { Ok(TlsListener { - acceptor: self.tls.acceptor()?, + acceptor: TlsAcceptor::from(Arc::new(self.tls.server_config()?)), listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?, config: self.tls, }) } + + fn candidate_endpoint(&self) -> io::Result { + let inner = self.inner.candidate_endpoint()?; + Ok(inner.with_tls(&self.tls)) + } } -impl Listener for TlsListener - where L::Connection: Unpin +impl Listener for TlsListener + where L: Listener::Connection>, + L::Connection: AsyncRead + AsyncWrite { - type Accept = L::Accept; + type Accept = L::Connection; type Connection = TlsStream; async fn accept(&self) -> io::Result { - self.listener.accept().await + Ok(self.listener.accept().await?) } - async fn connect(&self, accept: L::Accept) -> io::Result { - let conn = self.listener.connect(accept).await?; + async fn connect(&self, conn: L::Connection) -> io::Result { self.acceptor.accept(conn).await } - fn socket_addr(&self) -> io::Result { - Ok(self.listener.socket_addr()?.with_tls(self.config.clone())) + fn endpoint(&self) -> io::Result { + Ok(self.listener.endpoint()?.with_tls(&self.config)) } } -impl Connection for TlsStream { - fn peer_address(&self) -> io::Result { - Ok(self.get_ref().0.peer_address()?.assume_tls()) +impl Connection for TlsStream { + fn endpoint(&self) -> io::Result { + Ok(self.get_ref().0.endpoint()?.assume_tls()) } #[cfg(feature = "mtls")] - fn peer_certificates(&self) -> Option> { + fn certificates(&self) -> Option> { let cert_chain = self.get_ref().1.peer_certificates()?; Some(Certificates::from(cert_chain)) } diff --git a/core/lib/src/listener/unix.rs b/core/lib/src/listener/unix.rs index ea1a367e..b29fc71d 100644 --- a/core/lib/src/listener/unix.rs +++ b/core/lib/src/listener/unix.rs @@ -68,6 +68,10 @@ impl Bindable for UdsConfig { Ok(UdsListener { lock, listener, path: self.path, }) } + + fn candidate_endpoint(&self) -> io::Result { + Ok(Endpoint::Unix(self.path.clone())) + } } impl Listener for UdsListener { @@ -83,13 +87,13 @@ impl Listener for UdsListener { Ok(accept) } - fn socket_addr(&self) -> io::Result { + fn endpoint(&self) -> io::Result { self.listener.local_addr()?.try_into() } } impl Connection for UnixStream { - fn peer_address(&self) -> io::Result { + fn endpoint(&self) -> io::Result { self.local_addr()?.try_into() } } diff --git a/core/lib/src/local/asynchronous/client.rs b/core/lib/src/local/asynchronous/client.rs index 2a45a331..e42254b2 100644 --- a/core/lib/src/local/asynchronous/client.rs +++ b/core/lib/src/local/asynchronous/client.rs @@ -59,12 +59,12 @@ impl Client { tracked: bool, secure: bool, ) -> Result { - let mut listener = Endpoint::new("local client"); + let mut endpoint = Endpoint::new("local client"); if secure { - listener = listener.assume_tls(); + endpoint = endpoint.assume_tls(); } - let rocket = rocket.local_launch(listener).await?; + let rocket = rocket.local_launch(endpoint).await?; let cookies = RwLock::new(cookie::CookieJar::new()); Ok(Client { rocket, cookies, tracked }) } diff --git a/core/lib/src/local/request.rs b/core/lib/src/local/request.rs index 1ec07400..7a156e28 100644 --- a/core/lib/src/local/request.rs +++ b/core/lib/src/local/request.rs @@ -99,26 +99,27 @@ macro_rules! pub_request_impl { /// Set the remote address of this request to `address`. /// - /// `address` may be any type that [can be converted into a `ListenerAddr`]. + /// `address` may be any type that [can be converted into a `Endpoint`]. /// If `address` fails to convert, the remote is left unchanged. /// - /// [can be converted into a `ListenerAddr`]: crate::listener::ListenerAddr#conversions + /// [can be converted into a `Endpoint`]: crate::listener::Endpoint#conversions /// /// # Examples /// /// Set the remote address to "8.8.8.8:80": /// /// ```rust - /// use std::net::{SocketAddrV4, Ipv4Addr}; + /// use std::net::Ipv4Addr; /// #[doc = $import] /// /// # Client::_test(|_, request, _| { /// let request: LocalRequest = request; - /// let req = request.remote("8.8.8.8:80"); + /// let req = request.remote("tcp:8.8.8.8:80"); /// - /// let addr = SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8).into(), 80); - /// assert_eq!(req.inner().remote().unwrap(), &addr); + /// let remote = req.inner().remote().unwrap().tcp().unwrap(); + /// assert_eq!(remote.ip(), Ipv4Addr::new(8, 8, 8, 8)); + /// assert_eq!(remote.port(), 80); /// # }); /// ``` #[inline] diff --git a/core/lib/src/phase.rs b/core/lib/src/phase.rs index b38deeea..24f8c265 100644 --- a/core/lib/src/phase.rs +++ b/core/lib/src/phase.rs @@ -2,7 +2,8 @@ use state::TypeMap; use figment::Figment; use crate::listener::Endpoint; -use crate::{Catcher, Config, Rocket, Route, Shutdown}; +use crate::shutdown::Stages; +use crate::{Catcher, Config, Rocket, Route}; use crate::router::Router; use crate::fairing::Fairings; @@ -99,7 +100,7 @@ phases! { pub(crate) figment: Figment, pub(crate) config: Config, pub(crate) state: TypeMap![Send + Sync], - pub(crate) shutdown: Shutdown, + pub(crate) shutdown: Stages, } /// The final launch [`Phase`]. See [Rocket#orbit](`Rocket#orbit`) for @@ -113,7 +114,7 @@ phases! { pub(crate) figment: Figment, pub(crate) config: Config, pub(crate) state: TypeMap![Send + Sync], - pub(crate) shutdown: Shutdown, - pub(crate) endpoint: Endpoint, + pub(crate) shutdown: Stages, + pub(crate) endpoints: Vec, } } diff --git a/core/lib/src/request/from_request.rs b/core/lib/src/request/from_request.rs index 25c33e73..a87b3bd4 100644 --- a/core/lib/src/request/from_request.rs +++ b/core/lib/src/request/from_request.rs @@ -501,7 +501,7 @@ impl<'r> FromRequest<'r> for std::net::SocketAddr { async fn from_request(request: &'r Request<'_>) -> Outcome { request.remote() - .and_then(|r| r.tcp()) + .and_then(|r| r.socket_addr()) .or_forward(Status::InternalServerError) } } diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 5ff7e4b7..0c92e671 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -39,7 +39,7 @@ pub struct Request<'r> { /// Information derived from an incoming connection, if any. #[derive(Clone, Default)] pub(crate) struct ConnectionMeta { - pub peer_address: Option>, + pub peer_endpoint: Option, #[cfg_attr(not(feature = "mtls"), allow(dead_code))] pub peer_certs: Option>>, } @@ -47,8 +47,8 @@ pub(crate) struct ConnectionMeta { impl From<&C> for ConnectionMeta { fn from(conn: &C) -> Self { ConnectionMeta { - peer_address: conn.peer_address().ok().map(Arc::new), - peer_certs: conn.peer_certificates().map(|c| c.into_owned()).map(Arc::new), + peer_endpoint: conn.endpoint().ok(), + peer_certs: conn.certificates().map(|c| c.into_owned()).map(Arc::new), } } } @@ -316,20 +316,21 @@ impl<'r> Request<'r> { /// # Example /// /// ```rust - /// use std::net::{SocketAddrV4, Ipv4Addr}; + /// use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + /// use rocket::listener::Endpoint; /// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap(); /// # let mut req = c.get("/"); /// # let request = req.inner_mut(); /// /// assert_eq!(request.remote(), None); /// - /// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000); - /// request.set_remote(localhost); - /// assert_eq!(request.remote().unwrap(), &localhost); + /// let localhost = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8111); + /// request.set_remote(Endpoint::Tcp(localhost)); + /// assert_eq!(request.remote().unwrap().tcp().unwrap(), localhost); /// ``` #[inline(always)] pub fn remote(&self) -> Option<&Endpoint> { - self.connection.peer_address.as_deref() + self.connection.peer_endpoint.as_ref() } /// Sets the remote address of `self` to `address`. @@ -339,20 +340,21 @@ impl<'r> Request<'r> { /// Set the remote address to be 127.0.0.1:8111: /// /// ```rust - /// use std::net::{SocketAddrV4, Ipv4Addr}; + /// use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + /// use rocket::listener::Endpoint; /// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap(); /// # let mut req = c.get("/"); /// # let request = req.inner_mut(); /// /// assert_eq!(request.remote(), None); /// - /// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8111); - /// request.set_remote(localhost); - /// assert_eq!(request.remote().unwrap(), &localhost); + /// let localhost = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8111); + /// request.set_remote(Endpoint::Tcp(localhost)); + /// assert_eq!(request.remote().unwrap().tcp().unwrap(), localhost); /// ``` #[inline(always)] - pub fn set_remote>(&mut self, address: A) { - self.connection.peer_address = Some(Arc::new(address.into())); + pub fn set_remote(&mut self, endpoint: Endpoint) { + self.connection.peer_endpoint = Some(endpoint.into()); } /// Returns the IP address of the configured @@ -491,14 +493,15 @@ impl<'r> Request<'r> { /// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap(); /// # let mut req = c.get("/"); /// # let request = req.inner_mut(); - /// # use std::net::{SocketAddrV4, Ipv4Addr}; + /// # use std::net::{SocketAddr, IpAddr, Ipv4Addr}; + /// # use rocket::listener::Endpoint; /// /// // starting without an "X-Real-IP" header or remote address /// assert!(request.client_ip().is_none()); /// /// // add a remote address; this is done by Rocket automatically - /// let localhost_9190 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9190); - /// request.set_remote(localhost_9190); + /// let localhost_9190 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9190); + /// request.set_remote(Endpoint::Tcp(localhost_9190)); /// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::LOCALHOST); /// /// // now with an X-Real-IP header, the default value for `ip_header`. @@ -507,7 +510,7 @@ impl<'r> Request<'r> { /// ``` #[inline] pub fn client_ip(&self) -> Option { - self.real_ip().or_else(|| Some(self.remote()?.tcp()?.ip())) + self.real_ip().or_else(|| self.remote()?.ip()) } /// Returns a wrapped borrow to the cookies in `self`. diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index 40570e7b..6d1706e7 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -1,14 +1,17 @@ use std::fmt; use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use std::time::Duration; use yansi::Paint; use either::Either; use figment::{Figment, Provider}; +use tokio::io::{AsyncRead, AsyncWrite}; -use crate::{Catcher, Config, Route, Shutdown, sentinel, shield::Shield}; -use crate::listener::{Endpoint, Bindable, DefaultListener}; +use crate::shutdown::{Stages, Shutdown}; +use crate::{sentinel, shield::Shield, Catcher, Config, Route}; +use crate::listener::{Bindable, DefaultListener, Endpoint, Listener}; use crate::router::Router; -use crate::util::TripWire; use crate::fairing::{Fairing, Fairings}; use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting}; use crate::phase::{Stateful, StateRef, State}; @@ -575,11 +578,11 @@ impl Rocket { // Ignite the rocket. let rocket: Rocket = Rocket(Igniting { - router, config, - shutdown: Shutdown(TripWire::new()), + shutdown: Stages::new(), figment: self.0.figment, fairings: self.0.fairings, state: self.0.state, + router, config, }); // Query the sentinels, abort if requested. @@ -630,7 +633,7 @@ impl Rocket { /// A completed graceful shutdown resolves the future returned by /// [`Rocket::launch()`]. If [`Shutdown::notify()`] is called _before_ an /// instance is launched, it will be immediately shutdown after liftoff. See - /// [`Shutdown`] and [`config::Shutdown`](crate::config::Shutdown) for + /// [`Shutdown`] and [`ShutdownConfig`](crate::config::ShutdownConfig) for /// details on graceful shutdown. /// /// # Example @@ -657,12 +660,12 @@ impl Rocket { /// } /// ``` pub fn shutdown(&self) -> Shutdown { - self.shutdown.clone() + self.shutdown.start.clone() } - pub(crate) fn into_orbit(self, address: Endpoint) -> Rocket { + pub(crate) fn into_orbit(self, endpoints: Vec) -> Rocket { Rocket(Orbiting { - endpoint: address, + endpoints, router: self.0.router, fairings: self.0.fairings, figment: self.0.figment, @@ -672,8 +675,8 @@ impl Rocket { }) } - async fn _local_launch(self, addr: Endpoint) -> Rocket { - let rocket = self.into_orbit(addr); + async fn _local_launch(self, endpoint: Endpoint) -> Rocket { + let rocket = self.into_orbit(vec![endpoint]); Rocket::liftoff(&rocket).await; rocket } @@ -687,13 +690,72 @@ impl Rocket { }) } - async fn _launch_on(self, bindable: B) -> Result, Error> { - let listener = bindable.bind().await.map_err(|e| ErrorKind::Bind(Box::new(e)))?; - self.serve(listener).await + async fn _launch_on(self, bindable: B) -> Result, Error> + where ::Connection: AsyncRead + AsyncWrite + { + let rocket = self.bind_and_serve(bindable, |rocket| async move { + let rocket = Arc::new(rocket); + + rocket.shutdown.spawn_listener(&rocket.config.shutdown); + if let Err(e) = tokio::spawn(Rocket::liftoff(rocket.clone())).await { + let rocket = rocket.try_wait_shutdown().await; + return Err(ErrorKind::Liftoff(rocket, Box::new(e)).into()); + } + + Ok(rocket) + }).await?; + + Ok(rocket.try_wait_shutdown().await.map_err(ErrorKind::Shutdown)?) } } impl Rocket { + /// Rocket wraps all connections in a `CancellableIo` struct, an internal + /// structure that gracefully closes I/O when it receives a signal. That + /// signal is the `shutdown` future. When the future resolves, + /// `CancellableIo` begins to terminate in grace, mercy, and finally force + /// close phases. Since all connections are wrapped in `CancellableIo`, this + /// eventually ends all I/O. + /// + /// At that point, unless a user spawned an infinite, stand-alone task that + /// isn't monitoring `Shutdown`, all tasks should resolve. This means that + /// all instances of the shared `Arc` are dropped and we can return + /// the owned instance of `Rocket`. + /// + /// Unfortunately, the Hyper `server` future resolves as soon as it has + /// finished processing requests without respect for ongoing responses. That + /// is, `server` resolves even when there are running tasks that are + /// generating a response. So, `server` resolving implies little to nothing + /// about the state of connections. As a result, we depend on the timing of + /// grace + mercy + some buffer to determine when all connections should be + /// closed, thus all tasks should be complete, thus all references to + /// `Arc` should be dropped and we can get back a unique reference. + async fn try_wait_shutdown(self: Arc) -> Result, Arc> { + info!("Shutting down. Waiting for shutdown fairings and pending I/O..."); + tokio::spawn({ + let rocket = self.clone(); + async move { rocket.fairings.handle_shutdown(&*rocket).await } + }); + + let config = &self.config.shutdown; + let wait = Duration::from_micros(250); + for period in [wait, config.grace(), wait, config.mercy(), wait * 4] { + if Arc::strong_count(&self) == 1 { break } + tokio::time::sleep(period).await; + } + + match Arc::try_unwrap(self) { + Ok(rocket) => { + info!("Graceful shutdown completed successfully."); + Ok(rocket.into_ignite()) + } + Err(rocket) => { + warn!("Shutdown failed: outstanding background I/O."); + Err(rocket) + } + } + } + pub(crate) fn into_ignite(self) -> Rocket { Rocket(Igniting { router: self.0.router, @@ -717,7 +779,7 @@ impl Rocket { launch_info!("{}{} {}", "🚀 ".emoji(), "Rocket has launched on".bold().primary().linger(), - rocket.endpoint().underline()); + rocket.endpoints[0].underline()); } /// Returns the finalized, active configuration. This is guaranteed to @@ -742,8 +804,8 @@ impl Rocket { &self.config } - pub fn endpoint(&self) -> &Endpoint { - &self.endpoint + pub fn endpoints(&self) -> impl Iterator { + self.endpoints.iter() } /// Returns a handle which can be used to trigger a shutdown and detect a @@ -751,8 +813,8 @@ impl Rocket { /// /// A completed graceful shutdown resolves the future returned by /// [`Rocket::launch()`]. See [`Shutdown`] and - /// [`config::Shutdown`](crate::config::Shutdown) for details on graceful - /// shutdown. + /// [`ShutdownConfig`](crate::config::ShutdownConfig) for details on + /// graceful shutdown. /// /// # Example /// @@ -774,7 +836,7 @@ impl Rocket { /// } /// ``` pub fn shutdown(&self) -> Shutdown { - self.shutdown.clone() + self.shutdown.start.clone() } } @@ -879,10 +941,10 @@ impl Rocket

{ } } - pub(crate) async fn local_launch(self, l: Endpoint) -> Result, Error> { + pub(crate) async fn local_launch(self, e: Endpoint) -> Result, Error> { let rocket = match self.0.into_state() { - State::Build(s) => Rocket::from(s).ignite().await?._local_launch(l).await, - State::Ignite(s) => Rocket::from(s)._local_launch(l).await, + State::Build(s) => Rocket::from(s).ignite().await?._local_launch(e).await, + State::Ignite(s) => Rocket::from(s)._local_launch(e).await, State::Orbit(s) => Rocket::from(s) }; @@ -941,7 +1003,9 @@ impl Rocket

{ } } - pub async fn launch_on(self, bindable: B) -> Result, Error> { + pub async fn launch_on(self, bindable: B) -> Result, Error> + where ::Connection: AsyncRead + AsyncWrite + { match self.0.into_state() { State::Build(s) => Rocket::from(s).ignite().await?._launch_on(bindable).await, State::Ignite(s) => Rocket::from(s)._launch_on(bindable).await, diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index acf4d7c9..d6ada473 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -213,8 +213,8 @@ mod tests { use std::str::FromStr; use super::*; - use crate::route::{Route, dummy_handler}; - use crate::http::{Method, Method::*, MediaType}; + use crate::route::dummy_handler; + use crate::http::{Method, Method::*}; fn dummy_route(ranked: bool, method: impl Into>, uri: &'static str) -> Route { let method = method.into().unwrap_or(Get); diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 5617f4fb..a9fb4600 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -102,7 +102,7 @@ mod test { use crate::route::dummy_handler; use crate::local::blocking::Client; - use crate::http::{Method, Method::*, uri::Origin}; + use crate::http::{Method::*, uri::Origin}; impl Router { fn has_collisions(&self) -> bool { diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 3fbe2ae7..1f6c0f0c 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -6,33 +6,35 @@ use std::time::Duration; use hyper::service::service_fn; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use hyper_util::server::conn::auto::Builder; -use futures::{Future, TryFutureExt, future::{select, Either::*}}; -use tokio::time::sleep; +use futures::{Future, TryFutureExt, future::Either::*}; +use tokio::io::{AsyncRead, AsyncWrite}; -use crate::{Request, Rocket, Orbit, Data, Ignite}; +use crate::{Ignite, Orbit, Request, Rocket}; use crate::request::ConnectionMeta; use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler}; -use crate::listener::{Listener, CancellableExt, BouncedExt}; -use crate::error::{Error, ErrorKind}; -use crate::data::IoStream; -use crate::util::ReaderStream; +use crate::listener::{Bindable, BouncedExt, CancellableExt, Listener}; +use crate::error::{log_server_error, ErrorKind}; +use crate::data::{IoStream, RawStream}; +use crate::util::{spawn_inspect, FutureExt, ReaderStream}; use crate::http::Status; +type Result = std::result::Result; + impl Rocket { - async fn service( + async fn service Into>>( self: Arc, - mut req: hyper::Request, + parts: http::request::Parts, + stream: T, + upgrade: Option, connection: ConnectionMeta, ) -> Result>, http::Error> { - let upgrade = hyper::upgrade::on(&mut req); - let (parts, incoming) = req.into_parts(); + let alt_svc = self.alt_svc(); let request = ErasedRequest::new(self, parts, |rocket, parts| { Request::from_hyp(rocket, parts, connection).unwrap_or_else(|e| e) }); let mut response = request.into_response( - incoming, - |incoming| Data::from(incoming), + stream, |rocket, request, data| Box::pin(rocket.preprocess(request, data)), |token, rocket, request, data| Box::pin(async move { if !request.errors.is_empty() { @@ -46,7 +48,7 @@ impl Rocket { ).await; let io_handler = response.to_io_handler(Rocket::extract_io_handler); - if let Some(handler) = io_handler { + if let (Some(handler), Some(upgrade)) = (io_handler, upgrade) { let upgrade = upgrade.map_ok(IoStream::from).map_err(io::Error::other); tokio::task::spawn(io_handler_task(upgrade, handler)); } @@ -58,12 +60,28 @@ impl Rocket { } if let Some(size) = response.inner().body().preset_size() { - builder = builder.header("Content-Length", size); + builder = builder.header(http::header::CONTENT_TYPE, size); + } + + if let Some(alt_svc) = alt_svc { + let value = http::HeaderValue::from_static(alt_svc); + builder = builder.header(http::header::ALT_SVC, value); } let chunk_size = response.inner().body().max_chunk_size(); builder.body(ReaderStream::with_capacity(response, chunk_size)) } + + fn alt_svc(&self) -> Option<&'static str> { + cfg!(feature = "http3-preview").then(|| { + static ALT_SVC: state::InitCell> = state::InitCell::new(); + + ALT_SVC.get_or_init(|| { + let addr = self.endpoints().find_map(|v| v.quic())?; + Some(format!("h3=\":{}\"", addr.port())) + }).as_deref() + })? + } } async fn io_handler_task(stream: S, mut handler: ErasedIoHandler) @@ -84,8 +102,51 @@ async fn io_handler_task(stream: S, mut handler: ErasedIoHandler) } impl Rocket { - pub(crate) async fn serve(self, listener: L) -> Result - where L: Listener + 'static + pub(crate) async fn bind_and_serve( + self, + bindable: B, + post_bind_callback: impl FnOnce(Rocket) -> R, + ) -> Result>> + where B: Bindable, + ::Connection: AsyncRead + AsyncWrite, + R: Future>>> + { + let binding_endpoint = bindable.candidate_endpoint().ok(); + let h12listener = bindable.bind() + .map_err(|e| ErrorKind::Bind(binding_endpoint, Box::new(e))) + .await?; + + let endpoint = h12listener.endpoint()?; + #[cfg(feature = "http3-preview")] + if let (Some(addr), Some(tls)) = (endpoint.tcp(), endpoint.tls_config()) { + let h3listener = crate::listener::quic::QuicListener::bind(addr, tls.clone()).await?; + let rocket = self.into_orbit(vec![h3listener.endpoint()?, endpoint]); + let rocket = post_bind_callback(rocket).await?; + + let http12 = tokio::task::spawn(rocket.clone().serve12(h12listener)); + let http3 = tokio::task::spawn(rocket.clone().serve3(h3listener)); + let (r1, r2) = tokio::join!(http12, http3); + r1.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??; + r2.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??; + return Ok(rocket); + } + + if cfg!(feature = "http3-preview") { + warn!("HTTP/3 cannot start without a valid TCP + TLS configuration."); + info_!("Falling back to HTTP/1 + HTTP/2 server."); + } + + let rocket = self.into_orbit(vec![endpoint]); + let rocket = post_bind_callback(rocket).await?; + rocket.clone().serve12(h12listener).await?; + Ok(rocket) + } +} + +impl Rocket { + pub(crate) async fn serve12(self: Arc, listener: L) -> Result<()> + where L: Listener + 'static, + L::Connection: AsyncRead + AsyncWrite { let mut builder = Builder::new(TokioExecutor::new()); let keep_alive = Duration::from_secs(self.config.keep_alive.into()); @@ -106,75 +167,64 @@ impl Rocket { } } - let listener = listener.bounced().cancellable(self.shutdown(), &self.config.shutdown); - let rocket = Arc::new(self.into_orbit(listener.socket_addr()?)); - let _ = tokio::spawn(Rocket::liftoff(rocket.clone())).await; + let (listener, server) = (Arc::new(listener.bounced()), Arc::new(builder)); + while let Some(accept) = listener.accept().unless(self.shutdown()).await? { + let (listener, rocket, server) = (listener.clone(), self.clone(), server.clone()); + spawn_inspect(|e| log_server_error(&**e), async move { + let conn = listener.connect(accept).io_unless(rocket.shutdown()).await?; + let meta = ConnectionMeta::from(&conn); + let service = service_fn(|mut req| { + let upgrade = hyper::upgrade::on(&mut req); + let (parts, incoming) = req.into_parts(); + rocket.clone().service(parts, incoming, Some(upgrade), meta.clone()) + }); - let (server, listener) = (Arc::new(builder), Arc::new(listener)); - while let Some(accept) = listener.accept_next().await { - let (listener, rocket, server) = (listener.clone(), rocket.clone(), server.clone()); - tokio::spawn({ - let result = async move { - let conn = TokioIo::new(listener.connect(accept).await?); - let meta = ConnectionMeta::from(conn.inner()); - let service = service_fn(|req| rocket.clone().service(req, meta.clone())); - let serve = pin!(server.serve_connection_with_upgrades(conn, service)); - match select(serve, rocket.shutdown()).await { - Left((result, _)) => result, - Right((_, mut conn)) => { - conn.as_mut().graceful_shutdown(); - conn.await - } - } - }; - - result.inspect_err(crate::error::log_server_error) + let io = TokioIo::new(conn.cancellable(rocket.shutdown.clone())); + let mut server = pin!(server.serve_connection_with_upgrades(io, service)); + match server.as_mut().or(rocket.shutdown()).await { + Left(result) => result, + Right(()) => { + server.as_mut().graceful_shutdown(); + server.await + }, + } }); } - // Rocket wraps all connections in a `CancellableIo` struct, an internal - // structure that gracefully closes I/O when it receives a signal. That - // signal is the `shutdown` future. When the future resolves, - // `CancellableIo` begins to terminate in grace, mercy, and finally - // force close phases. Since all connections are wrapped in - // `CancellableIo`, this eventually ends all I/O. - // - // At that point, unless a user spawned an infinite, stand-alone task - // that isn't monitoring `Shutdown`, all tasks should resolve. This - // means that all instances of the shared `Arc` are dropped and - // we can return the owned instance of `Rocket`. - // - // Unfortunately, the Hyper `server` future resolves as soon as it has - // finished processing requests without respect for ongoing responses. - // That is, `server` resolves even when there are running tasks that are - // generating a response. So, `server` resolving implies little to - // nothing about the state of connections. As a result, we depend on the - // timing of grace + mercy + some buffer to determine when all - // connections should be closed, thus all tasks should be complete, thus - // all references to `Arc` should be dropped and we can get back - // a unique reference. - info!("Shutting down. Waiting for shutdown fairings and pending I/O..."); - tokio::spawn({ - let rocket = rocket.clone(); - async move { rocket.fairings.handle_shutdown(&*rocket).await } - }); + Ok(()) + } - let config = &rocket.config.shutdown; - let wait = Duration::from_micros(250); - for period in [wait, config.grace(), wait, config.mercy(), wait * 4] { - if Arc::strong_count(&rocket) == 1 { break } - sleep(period).await; + #[cfg(feature = "http3-preview")] + async fn serve3(self: Arc, listener: crate::listener::quic::QuicListener) -> Result<()> { + let rocket = self.clone(); + let listener = Arc::new(listener.bounced()); + while let Some(accept) = listener.accept().unless(rocket.shutdown()).await? { + let (listener, rocket) = (listener.clone(), rocket.clone()); + spawn_inspect(|e: &io::Error| log_server_error(e), async move { + let mut stream = listener.connect(accept).io_unless(rocket.shutdown()).await?; + while let Some(mut conn) = stream.accept().io_unless(rocket.shutdown()).await? { + let rocket = rocket.clone(); + spawn_inspect(|e: &io::Error| log_server_error(e), async move { + let meta = ConnectionMeta::from(&conn); + let rx = conn.rx.cancellable(rocket.shutdown.clone()); + let response = rocket.clone() + .service(conn.parts, rx, None, ConnectionMeta::from(meta)) + .map_err(io::Error::other) + .io_unless(rocket.shutdown.mercy.clone()) + .await?; + + let grace = rocket.shutdown.grace.clone(); + match conn.tx.send_response(response).or(grace).await { + Left(result) => result, + Right(_) => Ok(conn.tx.cancel()), + } + }); + } + + Ok(()) + }); } - match Arc::try_unwrap(rocket) { - Ok(rocket) => { - info!("Graceful shutdown completed successfully."); - Ok(rocket.into_ignite()) - } - Err(rocket) => { - warn!("Shutdown failed: outstanding background I/O."); - Err(Error::new(ErrorKind::Shutdown(rocket))) - } - } + Ok(()) } } diff --git a/core/lib/src/shield/shield.rs b/core/lib/src/shield/shield.rs index f3a3aeb2..65e0ffcd 100644 --- a/core/lib/src/shield/shield.rs +++ b/core/lib/src/shield/shield.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use std::sync::atomic::{AtomicBool, Ordering}; -use state::InitCell; use yansi::Paint; use crate::{Rocket, Request, Response, Orbit, Config}; @@ -68,11 +67,18 @@ use crate::shield::*; /// policy. pub struct Shield { /// Enabled policies where the key is the header name. - policies: HashMap<&'static UncasedStr, Box>, + policies: HashMap<&'static UncasedStr, Header<'static>>, /// Whether to enforce HSTS even though the user didn't enable it. force_hsts: AtomicBool, - /// Headers pre-rendered at liftoff from the configured policies. - rendered: InitCell>>, +} + +impl Clone for Shield { + fn clone(&self) -> Self { + Self { + policies: self.policies.clone(), + force_hsts: AtomicBool::from(self.force_hsts.load(Ordering::Acquire)), + } + } } impl Default for Shield { @@ -111,7 +117,6 @@ impl Shield { Shield { policies: HashMap::new(), force_hsts: AtomicBool::new(false), - rendered: InitCell::new(), } } @@ -129,8 +134,7 @@ impl Shield { /// let shield = Shield::new().enable(NoSniff::default()); /// ``` pub fn enable(mut self, policy: P) -> Self { - self.rendered = InitCell::new(); - self.policies.insert(P::NAME.into(), Box::new(policy)); + self.policies.insert(P::NAME.into(), policy.header()); self } @@ -145,7 +149,6 @@ impl Shield { /// let shield = Shield::default().disable::(); /// ``` pub fn disable(mut self) -> Self { - self.rendered = InitCell::new(); self.policies.remove(UncasedStr::new(P::NAME)); self } @@ -172,20 +175,6 @@ impl Shield { pub fn is_enabled(&self) -> bool { self.policies.contains_key(UncasedStr::new(P::NAME)) } - - fn headers(&self) -> &[Header<'static>] { - self.rendered.get_or_init(|| { - let mut headers: Vec<_> = self.policies.values() - .map(|p| p.header()) - .collect(); - - if self.force_hsts.load(Ordering::Acquire) { - headers.push(Policy::header(&Hsts::default())); - } - - headers - }) - } } #[crate::async_trait] @@ -198,7 +187,7 @@ impl Fairing for Shield { } async fn on_liftoff(&self, rocket: &Rocket) { - let force_hsts = rocket.endpoint().is_tls() + let force_hsts = rocket.endpoints().all(|v| v.is_tls()) && rocket.figment().profile() != Config::DEBUG_PROFILE && !self.is_enabled::(); @@ -206,10 +195,10 @@ impl Fairing for Shield { self.force_hsts.store(true, Ordering::Release); } - if !self.headers().is_empty() { + if !self.policies.is_empty() { info!("{}{}:", "🛡️ ".emoji(), "Shield".magenta()); - for header in self.headers() { + for header in self.policies.values() { info_!("{}: {}", header.name(), header.value().primary()); } @@ -224,7 +213,7 @@ impl Fairing for Shield { async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) { // Set all of the headers in `self.policies` in `response` as long as // the header is not already in the response. - for header in self.headers() { + for header in self.policies.values() { if response.headers().contains(header.name()) { warn!("Shield: response contains a '{}' header.", header.name()); warn_!("Refusing to overwrite existing header."); diff --git a/core/lib/src/config/shutdown.rs b/core/lib/src/shutdown/config.rs similarity index 84% rename from core/lib/src/config/shutdown.rs rename to core/lib/src/shutdown/config.rs index 2353a4fb..fc4cb9de 100644 --- a/core/lib/src/config/shutdown.rs +++ b/core/lib/src/shutdown/config.rs @@ -6,60 +6,7 @@ use std::collections::HashSet; use futures::stream::Stream; use serde::{Deserialize, Serialize}; -/// A Unix signal for triggering graceful shutdown. -/// -/// Each variant corresponds to a Unix process signal which can be used to -/// trigger a graceful shutdown. See [`Shutdown`] for details. -/// -/// ## (De)serialization -/// -/// A `Sig` variant serializes and deserializes as a lowercase string equal to -/// the name of the variant: `"alrm"` for [`Sig::Alrm`], `"chld"` for -/// [`Sig::Chld`], and so on. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -#[cfg_attr(nightly, doc(cfg(unix)))] -pub enum Sig { - /// The `SIGALRM` Unix signal. - Alrm, - /// The `SIGCHLD` Unix signal. - Chld, - /// The `SIGHUP` Unix signal. - Hup, - /// The `SIGINT` Unix signal. - Int, - /// The `SIGIO` Unix signal. - Io, - /// The `SIGPIPE` Unix signal. - Pipe, - /// The `SIGQUIT` Unix signal. - Quit, - /// The `SIGTERM` Unix signal. - Term, - /// The `SIGUSR1` Unix signal. - Usr1, - /// The `SIGUSR2` Unix signal. - Usr2 -} - -impl fmt::Display for Sig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match self { - Sig::Alrm => "SIGALRM", - Sig::Chld => "SIGCHLD", - Sig::Hup => "SIGHUP", - Sig::Int => "SIGINT", - Sig::Io => "SIGIO", - Sig::Pipe => "SIGPIPE", - Sig::Quit => "SIGQUIT", - Sig::Term => "SIGTERM", - Sig::Usr1 => "SIGUSR1", - Sig::Usr2 => "SIGUSR2", - }; - - s.fmt(f) - } -} +use crate::shutdown::Sig; /// Graceful shutdown configuration. /// @@ -94,11 +41,13 @@ impl fmt::Display for Sig { /// /// Once a shutdown is triggered, Rocket stops accepting new connections and /// waits at most `grace` seconds before initiating connection shutdown. -/// Applications can `await` the [`Shutdown`](crate::Shutdown) future to detect +/// Applications can `await` the [`Shutdown`] future to detect /// a shutdown and cancel any server-initiated I/O, such as from [infinite /// responders](crate::response::stream#graceful-shutdown), to avoid abrupt I/O /// cancellation. /// +/// [`Shutdown`]: crate::Shutdown +/// /// # Mercy Period /// /// After the grace period has elapsed, Rocket initiates connection shutdown, @@ -125,7 +74,8 @@ impl fmt::Display for Sig { /// prevent _buggy_ code, such as an unintended infinite loop or unknown use of /// blocking I/O, from preventing shutdown. /// -/// This behavior can be disabled by setting [`Shutdown::force`] to `false`. +/// This behavior can be disabled by setting [`ShutdownConfig::force`] to +/// `false`. /// /// # Example /// @@ -169,13 +119,13 @@ impl fmt::Display for Sig { /// /// ```rust /// # use rocket::figment::{Figment, providers::{Format, Toml}}; -/// use rocket::config::{Config, Shutdown}; +/// use rocket::config::{Config, ShutdownConfig}; /// /// #[cfg(unix)] /// use rocket::config::Sig; /// /// let config = Config { -/// shutdown: Shutdown { +/// shutdown: ShutdownConfig { /// ctrlc: false, /// #[cfg(unix)] /// signals: { @@ -204,7 +154,7 @@ impl fmt::Display for Sig { /// } /// ``` #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct Shutdown { +pub struct ShutdownConfig { /// Whether `ctrl-c` (`SIGINT`) initiates a server shutdown. /// /// **default: `true`** @@ -245,9 +195,9 @@ pub struct Shutdown { /// _always_ be done using a public constructor or update syntax: /// /// ```rust - /// use rocket::config::Shutdown; + /// use rocket::config::ShutdownConfig; /// - /// let config = Shutdown { + /// let config = ShutdownConfig { /// grace: 5, /// mercy: 10, /// ..Default::default() @@ -258,7 +208,7 @@ pub struct Shutdown { pub __non_exhaustive: (), } -impl fmt::Display for Shutdown { +impl fmt::Display for ShutdownConfig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "ctrlc = {}, force = {}, ", self.ctrlc, self.force)?; @@ -276,9 +226,9 @@ impl fmt::Display for Shutdown { } } -impl Default for Shutdown { +impl Default for ShutdownConfig { fn default() -> Self { - Shutdown { + ShutdownConfig { ctrlc: true, #[cfg(unix)] signals: { let mut set = HashSet::new(); set.insert(Sig::Term); set }, @@ -290,7 +240,7 @@ impl Default for Shutdown { } } -impl Shutdown { +impl ShutdownConfig { pub(crate) fn grace(&self) -> Duration { Duration::from_secs(self.grace as u64) } diff --git a/core/lib/src/shutdown.rs b/core/lib/src/shutdown/handle.rs similarity index 54% rename from core/lib/src/shutdown.rs rename to core/lib/src/shutdown/handle.rs index 43a667af..862a11df 100644 --- a/core/lib/src/shutdown.rs +++ b/core/lib/src/shutdown/handle.rs @@ -2,22 +2,22 @@ use std::future::Future; use std::task::{Context, Poll}; use std::pin::Pin; -use futures::FutureExt; +use futures::{FutureExt, StreamExt}; +use crate::shutdown::{ShutdownConfig, TripWire}; use crate::request::{FromRequest, Outcome, Request}; -use crate::util::TripWire; /// A request guard and future for graceful shutdown. /// /// A server shutdown is manually requested by calling [`Shutdown::notify()`] -/// or, if enabled, through [automatic triggers] like `Ctrl-C`. Rocket will stop accepting new -/// requests, finish handling any pending requests, wait a grace period before -/// cancelling any outstanding I/O, and return `Ok()` to the caller of -/// [`Rocket::launch()`]. Graceful shutdown is configured via -/// [`config::Shutdown`](crate::config::Shutdown). +/// or, if enabled, through [automatic triggers] like `Ctrl-C`. Rocket will stop +/// accepting new requests, finish handling any pending requests, wait a grace +/// period before cancelling any outstanding I/O, and return `Ok()` to the +/// caller of [`Rocket::launch()`]. Graceful shutdown is configured via +/// [`ShutdownConfig`](crate::config::ShutdownConfig). /// /// [`Rocket::launch()`]: crate::Rocket::launch() -/// [automatic triggers]: crate::config::Shutdown#triggers +/// [automatic triggers]: crate::shutdown::Shutdown#triggers /// /// # Detecting Shutdown /// @@ -65,15 +65,30 @@ use crate::util::TripWire; /// ``` #[derive(Debug, Clone)] #[must_use = "`Shutdown` does nothing unless polled or `notify`ed"] -pub struct Shutdown(pub(crate) TripWire); +pub struct Shutdown { + wire: TripWire, +} + +#[derive(Debug, Clone)] +pub struct Stages { + pub start: Shutdown, + pub grace: Shutdown, + pub mercy: Shutdown, +} impl Shutdown { + fn new() -> Self { + Shutdown { + wire: TripWire::new(), + } + } + /// Notify the application to shut down gracefully. /// /// This function returns immediately; pending requests will continue to run /// until completion or expiration of the grace period, which ever comes /// first, before the actual shutdown occurs. The grace period can be - /// configured via [`Shutdown::grace`](crate::config::Shutdown::grace). + /// configured via [`Shutdown::grace`](crate::config::ShutdownConfig::grace). /// /// ```rust /// # use rocket::*; @@ -85,9 +100,37 @@ impl Shutdown { /// "Shutting down..." /// } /// ``` - #[inline] - pub fn notify(self) { - self.0.trip(); + #[inline(always)] + pub fn notify(&self) { + self.wire.trip(); + } + + /// Returns `true` if `Shutdown::notify()` has already been called. + /// + /// # Example + /// + /// ```rust + /// # use rocket::*; + /// use rocket::Shutdown; + /// + /// #[get("/shutdown")] + /// fn shutdown(shutdown: Shutdown) { + /// shutdown.notify(); + /// assert!(shutdown.notified()); + /// } + /// ``` + #[must_use] + #[inline(always)] + pub fn notified(&self) -> bool { + self.wire.tripped() + } +} + +impl Future for Shutdown { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.wire.poll_unpin(cx) } } @@ -101,11 +144,41 @@ impl<'r> FromRequest<'r> for Shutdown { } } -impl Future for Shutdown { - type Output = (); +impl Stages { + pub fn new() -> Self { + Stages { + start: Shutdown::new(), + grace: Shutdown::new(), + mercy: Shutdown::new(), + } + } - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.0.poll_unpin(cx) + pub(crate) fn spawn_listener(&self, config: &ShutdownConfig) { + use futures::stream; + use futures::future::{select, Either}; + + let mut signal = match config.signal_stream() { + Some(stream) => Either::Left(stream.chain(stream::pending())), + None => Either::Right(stream::pending()), + }; + + let start = self.start.clone(); + let (grace, grace_duration) = (self.grace.clone(), config.grace()); + let (mercy, mercy_duration) = (self.mercy.clone(), config.mercy()); + tokio::spawn(async move { + if let Either::Left((sig, start)) = select(signal.next(), start).await { + warn!("Received {}. Shutdown started.", sig.unwrap()); + start.notify(); + } + + tokio::time::sleep(grace_duration).await; + warn!("Shutdown grace period elapsed. Shutting down I/O."); + grace.notify(); + + tokio::time::sleep(mercy_duration).await; + warn!("Mercy period elapsed. Terminating I/O."); + mercy.notify(); + }); } } diff --git a/core/lib/src/shutdown/mod.rs b/core/lib/src/shutdown/mod.rs new file mode 100644 index 00000000..d68fddf3 --- /dev/null +++ b/core/lib/src/shutdown/mod.rs @@ -0,0 +1,13 @@ +//! Shutdown configuration and notification handle. + +mod tripwire; +mod handle; +mod sig; +mod config; + +pub(crate) use tripwire::TripWire; +pub(crate) use handle::Stages; + +pub use config::ShutdownConfig; +pub use handle::Shutdown; +pub use sig::Sig; diff --git a/core/lib/src/shutdown/sig.rs b/core/lib/src/shutdown/sig.rs new file mode 100644 index 00000000..2f20b7a4 --- /dev/null +++ b/core/lib/src/shutdown/sig.rs @@ -0,0 +1,58 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// A Unix signal for triggering graceful shutdown. +/// +/// Each variant corresponds to a Unix process signal which can be used to +/// trigger a graceful shutdown. See [`Shutdown`](crate::Shutdown) for details. +/// +/// ## (De)serialization +/// +/// A `Sig` variant serializes and deserializes as a lowercase string equal to +/// the name of the variant: `"alrm"` for [`Sig::Alrm`], `"chld"` for +/// [`Sig::Chld`], and so on. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +#[cfg_attr(nightly, doc(cfg(unix)))] +pub enum Sig { + /// The `SIGALRM` Unix signal. + Alrm, + /// The `SIGCHLD` Unix signal. + Chld, + /// The `SIGHUP` Unix signal. + Hup, + /// The `SIGINT` Unix signal. + Int, + /// The `SIGIO` Unix signal. + Io, + /// The `SIGPIPE` Unix signal. + Pipe, + /// The `SIGQUIT` Unix signal. + Quit, + /// The `SIGTERM` Unix signal. + Term, + /// The `SIGUSR1` Unix signal. + Usr1, + /// The `SIGUSR2` Unix signal. + Usr2 +} + +impl fmt::Display for Sig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + Sig::Alrm => "SIGALRM", + Sig::Chld => "SIGCHLD", + Sig::Hup => "SIGHUP", + Sig::Int => "SIGINT", + Sig::Io => "SIGIO", + Sig::Pipe => "SIGPIPE", + Sig::Quit => "SIGQUIT", + Sig::Term => "SIGTERM", + Sig::Usr1 => "SIGUSR1", + Sig::Usr2 => "SIGUSR2", + }; + + s.fmt(f) + } +} diff --git a/core/lib/src/util/tripwire.rs b/core/lib/src/shutdown/tripwire.rs similarity index 77% rename from core/lib/src/util/tripwire.rs rename to core/lib/src/shutdown/tripwire.rs index c4d649bf..6e70cc7a 100644 --- a/core/lib/src/util/tripwire.rs +++ b/core/lib/src/shutdown/tripwire.rs @@ -3,6 +3,8 @@ use std::{ops::Deref, pin::Pin, future::Future}; use std::task::{Context, Poll}; use std::sync::{Arc, atomic::{AtomicBool, Ordering}}; +use futures::future::FusedFuture; +use tokio::sync::futures::Notified; use tokio::sync::Notify; #[doc(hidden)] @@ -15,7 +17,7 @@ pub struct State { pub struct TripWire { state: Arc, // `Notified` is `!Unpin`. Even if we could name it, we'd need to pin it. - event: Option + Send + Sync>>>, + event: Option>>>, } impl Deref for TripWire { @@ -35,6 +37,13 @@ impl Clone for TripWire { } } +impl Drop for TripWire { + fn drop(&mut self) { + // SAFETY: Ensure we drop the self-reference before `self`. + self.event = None; + } +} + impl fmt::Debug for TripWire { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TripWire") @@ -47,35 +56,22 @@ impl Future for TripWire { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.tripped.load(Ordering::Acquire) { + if self.tripped() { self.event = None; return Poll::Ready(()); } if self.event.is_none() { - let state = self.state.clone(); - self.event = Some(Box::pin(async move { - let notified = state.notify.notified(); - notified.await - })); + let notified = self.state.notify.notified(); + + // SAFETY: This is a self reference to the `state`. + self.event = Some(Box::pin(unsafe { std::mem::transmute(notified) })); } if let Some(ref mut event) = self.event { - if event.as_mut().poll(cx).is_ready() { - // We need to call `trip()` to avoid a race condition where: - // 1) many trip wires have seen !self.tripped but have not - // polled for `self.event` yet, so are not subscribed - // 2) trip() is called, adding a permit to `event` - // 3) some trip wires poll `event` for the first time - // 4) one of those wins, returns `Ready()` - // 5) the rest return pending - // - // Without this `self.trip()` those will never be awoken. With - // the call to self.trip(), those that made it to poll() in 3) - // will be awoken by `notify_waiters()`. For those the didn't, - // one will be awoken by `notify_one()`, which will in-turn call - // self.trip(), awaking more until there are no more to awake. - self.trip(); + // The order here is important! We need to know: + // !self.tripped() => not notified == notified => self.tripped() + if event.as_mut().poll(cx).is_ready() || self.tripped() { self.event = None; return Poll::Ready(()); } @@ -85,6 +81,12 @@ impl Future for TripWire { } } +impl FusedFuture for TripWire { + fn is_terminated(&self) -> bool { + self.tripped() + } +} + impl TripWire { pub fn new() -> Self { TripWire { @@ -99,7 +101,6 @@ impl TripWire { pub fn trip(&self) { self.tripped.store(true, Ordering::Release); self.notify.notify_waiters(); - self.notify.notify_one(); } #[inline(always)] diff --git a/core/lib/src/tls/config.rs b/core/lib/src/tls/config.rs index 3131e16d..533b5793 100644 --- a/core/lib/src/tls/config.rs +++ b/core/lib/src/tls/config.rs @@ -427,7 +427,7 @@ impl TlsConfig { } pub fn validate(&self) -> Result<(), crate::tls::Error> { - self.acceptor().map(|_| ()) + self.server_config().map(|_| ()) } } diff --git a/core/lib/src/util/mod.rs b/core/lib/src/util/mod.rs index d3055f36..6b5df12e 100644 --- a/core/lib/src/util/mod.rs +++ b/core/lib/src/util/mod.rs @@ -1,5 +1,4 @@ mod chain; -mod tripwire; mod reader_stream; mod join; @@ -7,6 +6,55 @@ mod join; pub mod unix; pub use chain::Chain; -pub use tripwire::TripWire; pub use reader_stream::ReaderStream; pub use join::join; + +#[track_caller] +pub fn spawn_inspect(or: F, future: Fut) + where F: FnOnce(&E) + Send + Sync + 'static, + E: Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, +{ + use futures::TryFutureExt; + tokio::spawn(future.inspect_err(or)); +} + +use std::io; +use std::pin::pin; +use std::future::Future; +use futures::future::{select, Either}; + +pub trait FutureExt: Future + Sized { + /// Await `self` or `other`, whichever finishes first. + async fn or(self, other: B) -> Either { + match futures::future::select(pin!(self), pin!(other)).await { + Either::Left((v, _)) => Either::Left(v), + Either::Right((v, _)) => Either::Right(v), + } + } + + /// Await `self` unless `trigger` completes. Returns `Ok(Some(T))` if `self` + /// completes successfully before `trigger`, `Err(E)` if `self` completes + /// unsuccessfully, and `Ok(None)` if `trigger` completes before `self`. + async fn unless(self, trigger: K) -> Result, E> + where Self: Future> + { + match select(pin!(self), pin!(trigger)).await { + Either::Left((v, _)) => Ok(Some(v?)), + Either::Right((_, _)) => Ok(None), + } + } + + /// Await `self` unless `trigger` completes. If `self` completes before + /// `trigger`, returns the result. Otherwise, always returns an `Err`. + async fn io_unless(self, trigger: K) -> std::io::Result + where Self: Future> + { + match select(pin!(self), pin!(trigger)).await { + Either::Left((v, _)) => v, + Either::Right((_, _)) => Err(io::Error::other("I/O terminated")), + } + } +} + +impl FutureExt for F { } diff --git a/core/lib/tests/on_launch_fairing_can_inspect_port.rs b/core/lib/tests/on_launch_fairing_can_inspect_port.rs index 80601280..9631e57c 100644 --- a/core/lib/tests/on_launch_fairing_can_inspect_port.rs +++ b/core/lib/tests/on_launch_fairing_can_inspect_port.rs @@ -10,7 +10,8 @@ async fn on_ignite_fairing_can_inspect_port() { let rocket = rocket::custom(Config::debug_default()) .attach(AdHoc::on_liftoff("Send Port -> Channel", move |rocket| { Box::pin(async move { - tx.send(rocket.endpoint().tcp().unwrap().port()).unwrap(); + let tcp = rocket.endpoints().find_map(|v| v.tcp()); + tx.send(tcp.unwrap().port()).expect("send okay"); }) })); diff --git a/docs/guide/10-configuration.md b/docs/guide/10-configuration.md index 3865b94b..e86d8970 100644 --- a/docs/guide/10-configuration.md +++ b/docs/guide/10-configuration.md @@ -21,24 +21,24 @@ is configured with. This means that no matter which configuration provider Rocket is asked to use, it must be able to read the following configuration values: -| key | kind | description | debug/release default | -|----------------------|-------------------|-------------------------------------------------|-------------------------| -| `address` | `IpAddr` | IP address to serve on. | `127.0.0.1` | -| `port` | `u16` | Port to serve on. | `8000` | -| `workers`* | `usize` | Number of threads to use for executing futures. | cpu core count | -| `max_blocking`* | `usize` | Limit on threads to start for blocking tasks. | `512` | -| `ident` | `string`, `false` | If and how to identify via the `Server` header. | `"Rocket"` | -| `ip_header` | `string`, `false` | IP header to inspect to get [client's real IP]. | `"X-Real-IP"` | -| `proxy_proto_header` | `string`, `false` | Header identifying [client to proxy protocol]. | `None` | -| `keep_alive` | `u32` | Keep-alive timeout seconds; disabled when `0`. | `5` | -| `log_level` | [`LogLevel`] | Max level to log. (off/normal/debug/critical) | `normal`/`critical` | -| `cli_colors` | [`CliColors`] | Whether to use colors and emoji when logging. | `"auto"` | -| `secret_key` | [`SecretKey`] | Secret key for signing and encrypting values. | `None` | -| `tls` | [`TlsConfig`] | TLS configuration, if any. | `None` | -| `limits` | [`Limits`] | Streaming read size limits. | [`Limits::default()`] | -| `limits.$name` | `&str`/`uint` | Read limit for `$name`. | form = "32KiB" | -| `ctrlc` | `bool` | Whether `ctrl-c` initiates a server shutdown. | `true` | -| `shutdown`* | [`Shutdown`] | Graceful shutdown configuration. | [`Shutdown::default()`] | +| key | kind | description | debug/release default | +|----------------------|--------------------|-------------------------------------------------|-------------------------------| +| `address` | `IpAddr` | IP address to serve on. | `127.0.0.1` | +| `port` | `u16` | Port to serve on. | `8000` | +| `workers`* | `usize` | Number of threads to use for executing futures. | cpu core count | +| `max_blocking`* | `usize` | Limit on threads to start for blocking tasks. | `512` | +| `ident` | `string`, `false` | If and how to identify via the `Server` header. | `"Rocket"` | +| `ip_header` | `string`, `false` | IP header to inspect to get [client's real IP]. | `"X-Real-IP"` | +| `proxy_proto_header` | `string`, `false` | Header identifying [client to proxy protocol]. | `None` | +| `keep_alive` | `u32` | Keep-alive timeout seconds; disabled when `0`. | `5` | +| `log_level` | [`LogLevel`] | Max level to log. (off/normal/debug/critical) | `normal`/`critical` | +| `cli_colors` | [`CliColors`] | Whether to use colors and emoji when logging. | `"auto"` | +| `secret_key` | [`SecretKey`] | Secret key for signing and encrypting values. | `None` | +| `tls` | [`TlsConfig`] | TLS configuration, if any. | `None` | +| `limits` | [`Limits`] | Streaming read size limits. | [`Limits::default()`] | +| `limits.$name` | `&str`/`uint` | Read limit for `$name`. | form = "32KiB" | +| `ctrlc` | `bool` | Whether `ctrl-c` initiates a server shutdown. | `true` | +| `shutdown`* | [`ShutdownConfig`] | Graceful shutdown configuration. | [`ShutdownConfig::default()`] | * Note: the `workers`, `max_blocking`, and `shutdown.force` configuration @@ -77,8 +77,8 @@ profile supplant any values with the same name in any profile. [`SecretKey`]: @api/master/rocket/config/struct.SecretKey.html [`CliColors`]: @api/master/rocket/config/enum.CliColors.html [`TlsConfig`]: @api/master/rocket/tls/struct.TlsConfig.html -[`Shutdown`]: @api/master/rocket/config/struct.Shutdown.html -[`Shutdown::default()`]: @api/master/rocket/config/struct.Shutdown.html#fields +[`ShutdownConfig`]: @api/master/rocket/shutdown/struct.ShutdownConfig.html +[`ShutdownConfig::default()`]: @api/master/rocket/shutdown/struct.ShutdownConfig.html#fields ## Default Provider diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml index b37758ff..27e5b61d 100644 --- a/examples/tls/Cargo.toml +++ b/examples/tls/Cargo.toml @@ -6,5 +6,5 @@ edition = "2021" publish = false [dependencies] -rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets"] } +rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets", "http3-preview"] } yansi = "1.0.1" diff --git a/examples/tls/src/main.rs b/examples/tls/src/main.rs index 4ce4254c..7147e7fa 100644 --- a/examples/tls/src/main.rs +++ b/examples/tls/src/main.rs @@ -1,10 +1,12 @@ -#[macro_use] extern crate rocket; +#[macro_use] +extern crate rocket; #[cfg(test)] mod tests; mod redirector; use rocket::mtls::Certificate; +use rocket::listener::Endpoint; #[get("/")] fn mutual(cert: Certificate<'_>) -> String { @@ -12,8 +14,11 @@ fn mutual(cert: Certificate<'_>) -> String { } #[get("/", rank = 2)] -fn hello() -> &'static str { - "Hello, world!" +fn hello(endpoint: Option<&Endpoint>) -> String { + match endpoint { + Some(endpoint) => format!("Hello, {endpoint}!"), + None => "Hello, world!".into(), + } } #[launch] diff --git a/examples/tls/src/redirector.rs b/examples/tls/src/redirector.rs index e490ee1b..fb42ac7b 100644 --- a/examples/tls/src/redirector.rs +++ b/examples/tls/src/redirector.rs @@ -74,7 +74,7 @@ impl Fairing for Redirector { } async fn on_liftoff(&self, rocket: &Rocket) { - let Some(tls_addr) = rocket.endpoint().tls().and_then(|tls| tls.tcp()) else { + let Some(tls_addr) = rocket.endpoints().find_map(|e| e.tls()?.tcp()) else { info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta()); warn_!("Main instance is not being served over TLS/TCP."); warn_!("Redirector refusing to start."); diff --git a/scripts/test.sh b/scripts/test.sh index 43468022..d4762ab4 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -128,6 +128,7 @@ function test_core() { FEATURES=( tokio-macros http2 + http3-preview secrets tls mtls