From 1516ca4fb6f475b7e296ddc365d1bd2accad5873 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Thu, 13 Apr 2017 00:18:31 -0700 Subject: [PATCH] Initial TLS support. This commit introduces TLS support, provided by `rustls` and a fork of `hyper-rustls`. TLS support is enabled via the `tls` feature and activated when the `tls` configuration parameter is set. A new `hello_tls` example illustrates its usage. This commit also introduces more robust and complete configuration settings via environment variables. In particular, quoted string, array, and table (dictionaries) based configuration parameters can now be set via environment variables. Resolves #28. --- Cargo.toml | 1 + examples/hello_tls/Cargo.toml | 11 ++ examples/hello_tls/Rocket.toml | 11 ++ examples/hello_tls/private/cert.pem | 37 +++++ examples/hello_tls/private/key.pem | 51 +++++++ examples/hello_tls/src/main.rs | 15 ++ examples/hello_tls/src/tests.rs | 13 ++ lib/Cargo.toml | 8 ++ lib/src/config/builder.rs | 27 ++++ lib/src/config/config.rs | 216 +++++++++++++++++++--------- lib/src/config/error.rs | 19 ++- lib/src/config/mod.rs | 118 +++++++++++++-- lib/src/config/toml_ext.rs | 113 +++++++++++++-- lib/src/data/data.rs | 65 +++++---- lib/src/data/data_stream.rs | 86 ++++++++++- lib/src/error.rs | 5 + lib/src/lib.rs | 3 + lib/src/logger.rs | 10 +- lib/src/rocket.rs | 75 ++++++++-- 19 files changed, 739 insertions(+), 145 deletions(-) create mode 100644 examples/hello_tls/Cargo.toml create mode 100644 examples/hello_tls/Rocket.toml create mode 100644 examples/hello_tls/private/cert.pem create mode 100644 examples/hello_tls/private/key.pem create mode 100644 examples/hello_tls/src/main.rs create mode 100644 examples/hello_tls/src/tests.rs diff --git a/Cargo.toml b/Cargo.toml index 47d91dbe..1f0af5ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,4 +33,5 @@ members = [ "examples/uuid", "examples/session", "examples/raw_sqlite", + "examples/hello_tls", ] diff --git a/examples/hello_tls/Cargo.toml b/examples/hello_tls/Cargo.toml new file mode 100644 index 00000000..d691440e --- /dev/null +++ b/examples/hello_tls/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "hello_tls" +version = "0.0.0" +workspace = "../../" + +[dependencies] +rocket = { path = "../../lib", features = ["tls"] } +rocket_codegen = { path = "../../codegen" } + +[dev-dependencies] +rocket = { path = "../../lib", features = ["testing"] } diff --git a/examples/hello_tls/Rocket.toml b/examples/hello_tls/Rocket.toml new file mode 100644 index 00000000..14c7fa63 --- /dev/null +++ b/examples/hello_tls/Rocket.toml @@ -0,0 +1,11 @@ +# The certificate/private key pair used here was generated via openssl: +# +# openssl req -x509 -newkey rsa:4096 -nodes -sha256 -days 3650 \ +# -keyout key.pem -out cert.pem +# +# The certificate is self-signed. As such, you will need to trust it directly +# for your browser to refer to the connection as secure. You should NEVER use +# this certificate/key pair. It is here for DEMONSTRATION PURPOSES ONLY. +[global.tls] +certs = "private/cert.pem" +key = "private/key.pem" diff --git a/examples/hello_tls/private/cert.pem b/examples/hello_tls/private/cert.pem new file mode 100644 index 00000000..3019e697 --- /dev/null +++ b/examples/hello_tls/private/cert.pem @@ -0,0 +1,37 @@ +-----BEGIN CERTIFICATE----- +MIIGXjCCBEagAwIBAgIJAJBTO2YLMz4tMA0GCSqGSIb3DQEBCwUAMHwxCzAJBgNV +BAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMREwDwYDVQQHEwhTdGFuZm9yZDEP +MA0GA1UEChMGUm9ja2V0MRcwFQYDVQQDEw5TZXJnaW8gQmVuaXRlejEbMBkGCSqG +SIb3DQEJARYMc2JAcm9ja2V0LnJzMB4XDTE3MDQwOTA0MTIxM1oXDTI3MDQwNzA0 +MTIxM1owfDELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExETAPBgNV +BAcTCFN0YW5mb3JkMQ8wDQYDVQQKEwZSb2NrZXQxFzAVBgNVBAMTDlNlcmdpbyBC +ZW5pdGV6MRswGQYJKoZIhvcNAQkBFgxzYkByb2NrZXQucnMwggIiMA0GCSqGSIb3 +DQEBAQUAA4ICDwAwggIKAoICAQCzdm0ZxLNP4TlJBI2IpeVT4S6hZeBkem/aj4NZ +mhHA06HXVqcUw3W03YQklhO7E305uU/BTRz5q0BIa2DCPyZDUCkwTjOZAuFiiZzc +AZz/zhu2RwLWeYttlvjKewrIe0k9zrPaPXpdcFe0xq2mcUon0fyRztL1H8EYEScb +/TJqM1LkWKGSJEOMDeEYMVnJn/x9yFgfC82u/4GBc3q3Si2uRLCMkTLsg6TC27EF +kCVuOISf1+CvAKgk2x29SGm/nYoTe+j6YLm12h41S6JlGO9zJnORlwb4Mz5h+72p +NBaVER72kNxwskTNg2IWur7NM2Xi/nAfZ7+YOopgwosRuZl8Nw6CcpWDkGdLnO7X +H18Wy/BXOamXVa65tWefwlCiJ8bkqZgik8AHX36KZzTzkDO5g/4JAQDh4G56paGu +hcd1LXkGvTDuaSN4BkHDuYucr89aliWV/AKzum4BJkyKk3lVWDb9nfwyTRegsZg5 +ipTW7xLhvxzjeoLuDHybRzsw+2NFQoHA4PUouzC1n2/+eJIVysa6p5UZXTcTNGVd +rdU3GmifpFDBv4NwQrQ1y2izw0b+dbZ7DBAQIqW3toHeBUmmTiHSmQR5QT3Dz9HA +l2npMu4S2ZKQYJj+zqxyETzrOgz76LW1yZ3uAbX7z0OxlOoC67XYGAAWlDyU4pZc +qcnR1QIDAQABo4HiMIHfMB0GA1UdDgQWBBSUdwqW1sNXbeS29wXaJL0P9glPmTCB +rwYDVR0jBIGnMIGkgBSUdwqW1sNXbeS29wXaJL0P9glPmaGBgKR+MHwxCzAJBgNV +BAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMREwDwYDVQQHEwhTdGFuZm9yZDEP +MA0GA1UEChMGUm9ja2V0MRcwFQYDVQQDEw5TZXJnaW8gQmVuaXRlejEbMBkGCSqG +SIb3DQEJARYMc2JAcm9ja2V0LnJzggkAkFM7ZgszPi0wDAYDVR0TBAUwAwEB/zAN +BgkqhkiG9w0BAQsFAAOCAgEAXfsz3a1iDWoyyobVlFV3NvVD5CeZ9oWh/IvmbgfB +XMEB3ZLy3Cqn4op1u6Kbo/L8E+YYMlIGdqUtAlpJrtF1k3KeVGBScx6YebCPWcuL +aRO/l1qUR70RhA/Yrz0iNqcTSkC2n9YYFtr7tOiTzS3kqN03XB6fJBsYnVG4LzR5 +1EdNjgGoISSVmKJwQh0Sjy+GuHDnsjtL5xPxf5OFA6bgJiYkpMgxv0VDzC3Bl6NL +oTYwnQ+/19yzoSZANlvwKi8UftHEpBXMAW2Yr3jEfKuSQIe3FPr2js2JWfpl12yQ +JZnDPEJOamxD4hvvWljENwcVMss9X9FGiQCoFIYmGja4JXZ7KqvOlgdSaS8TUqCp +qHcSJpEiJQAJQton607EjWxBBWVEMEQYZx5nLFifxwexxv1jpbGeh+ehAwRlvrZU +nXR9miv/ohw6HmopNXmXcTCJsT8/OHb4g7cUs3scUmuySMZe1dKht+o+XmkWfF9b +fgqNz9so1ls9oyg/qjuMwh5wNNUsPQJGITmzTOfGGu7engyil744flO5aQFfSwcm +zQ7ZzRh+jDPI/rG8xbrYpXXK3+xln03O/96AC6iEELA0+A2PeworEkz47nEPrqlC +Fr17Aya+rJsrN9JXL1Uz87k3XfySNc6xT8zitNGzgxtgKthUg6fU9oOt/79HuOeL +g+Y= +-----END CERTIFICATE----- diff --git a/examples/hello_tls/private/key.pem b/examples/hello_tls/private/key.pem new file mode 100644 index 00000000..13ee4c54 --- /dev/null +++ b/examples/hello_tls/private/key.pem @@ -0,0 +1,51 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIJKQIBAAKCAgEAs3ZtGcSzT+E5SQSNiKXlU+EuoWXgZHpv2o+DWZoRwNOh11an +FMN1tN2EJJYTuxN9OblPwU0c+atASGtgwj8mQ1ApME4zmQLhYomc3AGc/84btkcC +1nmLbZb4ynsKyHtJPc6z2j16XXBXtMatpnFKJ9H8kc7S9R/BGBEnG/0yajNS5Fih +kiRDjA3hGDFZyZ/8fchYHwvNrv+BgXN6t0otrkSwjJEy7IOkwtuxBZAlbjiEn9fg +rwCoJNsdvUhpv52KE3vo+mC5tdoeNUuiZRjvcyZzkZcG+DM+Yfu9qTQWlREe9pDc +cLJEzYNiFrq+zTNl4v5wH2e/mDqKYMKLEbmZfDcOgnKVg5BnS5zu1x9fFsvwVzmp +l1WuubVnn8JQoifG5KmYIpPAB19+imc085AzuYP+CQEA4eBueqWhroXHdS15Br0w +7mkjeAZBw7mLnK/PWpYllfwCs7puASZMipN5VVg2/Z38Mk0XoLGYOYqU1u8S4b8c +43qC7gx8m0c7MPtjRUKBwOD1KLswtZ9v/niSFcrGuqeVGV03EzRlXa3VNxpon6RQ +wb+DcEK0Nctos8NG/nW2ewwQECKlt7aB3gVJpk4h0pkEeUE9w8/RwJdp6TLuEtmS +kGCY/s6schE86zoM++i1tcmd7gG1+89DsZTqAuu12BgAFpQ8lOKWXKnJ0dUCAwEA +AQKCAgBIdkLrKq80S75zqzDywfls+vl3FcmbCIztdREWNs2ATHOGnWhtS9bVJrRa +iXaCDQZ9LkPzyw0uCmW0WBcDl7f9afqXlJvk5nLW9LWvZ79a0oACA34z13Pi1hiy +uSfLd2xFVpbsQfKMk/X1+lrXX9sPZQxUW2x2qVGwRAzEkmGu2/ZWWSsz9QyJGnmO +6S5V6RFsQF7EemGcjXJfMJ+WLo9vVDDtMRucwDLgsxAxLNjQPmXenK4OO3epGghS +C1EXm6bK4zdZEYEq2l1kK5vwsjbNCfOUD6Uyxo4jxh/4mB2eJwGXkTpRDsoVKT2L +6+9qr5wuIYpoQ93qu4hwNV0t1QERp4HYTrX6WzDPLCtoHlIfCfzXH05jYa9n/nJD +Ow4eeeK5RE7/9/fKJWX2/R45iz87QKeS2H+ps3IOirK1P4Y/4FdIoHVyHjaCeoGp +YeSXjTWz6OEcHX9qcShdVu8ILkJlBJ4xftxNKWw8d/jKG3xOaxVNAnTGGAx113Pa +RIZcGNfroQcceV0mRAQTTpZyV7jj+lvC9lXqyP4tJJc4YjSWG6ITPrYfKVYBAhNC +K8MutrCaEH87SzIssY3DvpSzIeZMafvR4VYHZbBsDpI8+9iIijXLzDR0IQUANcah +bLWzjhDXrE5f1Et1hyHCAkCzsPIyx67QugYpJ8zTDUp6oYVaaQKCAQEA3XjPYrAJ +kokPk8ErDYTe0lMRGKvIrV3TH5hkfTIlkMfn4M6wVS3WQYYEMwgcJHVILpHkxtuD +ZcNIrey5f3IlwvSGo30w9L4IQCt5nzU0GpYqeN4Uty+0jDO0pygPYRrNAedSQSYa +jgSwoQCiu/qVzb+fPD2n6Hfxn/+cBP/mw6JBZZRb4aIhAluIt958J7NB/AwY+o03 +rG6lq41UUYDtydt/yP3q7QlRBFkY+anm7O7Iy80e4oPciXbettsMuHPa8/zF1TWJ +IxG2v9C4Q++tvj9zKP66l/5JVKVdtexGtPJS9i8lIkPkqX7K9peOUDNwyVAyG4kB +2SxKysu4sTq5ZwKCAQEAz3D6UTbTnzkrQFDnKaYUpMW4jfKUFr/n+s5iIN3FPk9t +ZuZ3YbrIroW50KxmVYFG+x1vyEZvB+Bdb7apCnqK4z2x4vbpl4bsE/uW8Kk+rQxx +gbUfuih1r0FMVE0kGL8MxJEL/4owOZM9G8hinxkVst4LswNe0MPwfJHgeLrx/xAM +lHq4hS5Pb4SYb+Z6iUJlzJpsQPX+JLW3cDqlfUBB9ckijtNkTbwsq7ETWRglnAP1 +flOCBe6l1aSDPiSa3fRSwA0PHTztVZt2TwnhDD2v40tTxCfeAwafeEwpUUgWivt0 +Doq0Tni5lZiPjqNfmbZSa8BKyeEtPuKZcKT4QMSJYwKCAQA8+TTHa8XG5Rs3x5fN +ygX6i8oKK8k9CbbFXRRVb4fuG0tYli7v1IXHVlkzn4j39J4hzCLbKLY9Pw10bNcJ +ImkJCn9C5YWj6+mjmRSL437rzun0itfTMzwW2WlkF+BcEJ/eZUw9CXuIG/xw5xbm +f+/cTGRPln3yv4rzTNEsgzOKKtKsX7MIJLXHy2GRlZxC5dRFyyLZYCWywGe2Glvb +cI6G43qD4HxcNBNtCgaZPdCI7Ji1m0xken8uDV71ossWwTbHs5DXyTxvPkI8/v6s +HYGM/jT7VV4T2Hth5YEuQ9WXnZt/ka08iMqca37/cuxIYlEr63tQH2E15D7XJE09 +5fgDAoIBAQDKC2hDofsMokoWIraEQlbpBgtzdkn2voPcLRg2msp6njIYf3DXp22/ +TlBlhwVFUt0nyMwPbUrHiSh4npiWtDSCkJyqS4PJKojWDb4+ORnqwqvrgdadIrs9 +L4SAt4Ho+GwfKIdfJeFCsr5aSRqFi5Eu3kbW3PmErNOXAR55eNwrah5WoBEI5spH +/AXdN8cx2ZH9borx2qbmand4wCZfkC6ujnEyW4Lek+GOeLI3nOVEyDZcDEogLQko +xUtvQ4fzlvziQdXuzGD9eKYK5bxkh9DAuaWk8I+0ssawDL5RhL0wMSog38guhjd8 +FVP9wfJjbMlqWah+aOwAzARXStbhfouxAoIBAQCw0HQEHJL0nTLT0SdkEAZF12Fo +NMdTh68xtQ1y2papT87L6J8/WmK1/O32KAJ3XXikl0QMTkhjHqf+eA+27C8L+jIR +HPhB4OGdlOufk1QXoaX1Z4vVHfyyXASspfZ1ecxFsQdC/lPnQ+ir/skHI1NNC9oO +Q39d37tyoKCOyZUhD0IsT5+6vVgyj6EcwiCmgwZ7PI3MKKoXx7HmkZ9e4hVYbXP6 +WsNsF2VKPHCre56T2FRT3xLxzN5uZOz0Hrau50Y/3RNuYiJJ2aufAYmv1r6HT/0W +BP2kmzmWlg1JJaU9vS9jZXabQZPp8bJ+fDtSZSa69AJESKRE9e3e895uW3gY +-----END RSA PRIVATE KEY----- diff --git a/examples/hello_tls/src/main.rs b/examples/hello_tls/src/main.rs new file mode 100644 index 00000000..524574ad --- /dev/null +++ b/examples/hello_tls/src/main.rs @@ -0,0 +1,15 @@ +#![feature(plugin)] +#![plugin(rocket_codegen)] + +extern crate rocket; + +#[cfg(test)] mod tests; + +#[get("/")] +fn hello() -> &'static str { + "Hello, world!" +} + +fn main() { + rocket::ignite().mount("/", routes![hello]).launch(); +} diff --git a/examples/hello_tls/src/tests.rs b/examples/hello_tls/src/tests.rs new file mode 100644 index 00000000..94b54e03 --- /dev/null +++ b/examples/hello_tls/src/tests.rs @@ -0,0 +1,13 @@ +use super::rocket; +use rocket::testing::MockRequest; +use rocket::http::Method::*; + +#[test] +fn hello_world() { + let rocket = rocket::ignite().mount("/", routes![super::hello]); + let mut req = MockRequest::new(Get, "/"); + let mut response = req.dispatch_with(&rocket); + + let body_str = response.body().and_then(|body| body.into_string()); + assert_eq!(body_str, Some("Hello, world!".to_string())); +} diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 1534b046..2f998332 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -16,6 +16,7 @@ categories = ["web-programming::http-server"] [features] testing = [] +tls = ["rustls", "hyper-rustls"] [dependencies] term-painter = "0.2" @@ -31,6 +32,13 @@ base64 = "0.4" smallvec = { git = "https://github.com/SergioBenitez/rust-smallvec" } pear = "0.0.8" pear_codegen = "0.0.8" +rustls = { version = "0.5.8", optional = true } + +[dependencies.hyper-rustls] +git = "https://github.com/SergioBenitez/hyper-rustls" +default-features = false +features = ["server"] +optional = true [dependencies.cookie] version = "0.7.4" diff --git a/lib/src/config/builder.rs b/lib/src/config/builder.rs index 18647b37..251ef98b 100644 --- a/lib/src/config/builder.rs +++ b/lib/src/config/builder.rs @@ -20,6 +20,8 @@ pub struct ConfigBuilder { pub log_level: LoggingLevel, /// The session key. pub session_key: Option, + /// TLS configuration (path to certificates file, path to private key file). + pub tls_config: Option<(String, String)>, /// Any extra parameters that aren't part of Rocket's config. pub extras: HashMap, /// The root directory of this config. @@ -63,6 +65,7 @@ impl ConfigBuilder { workers: config.workers, log_level: config.log_level, session_key: None, + tls_config: None, extras: config.extras, root: root_dir, } @@ -162,6 +165,26 @@ impl ConfigBuilder { self } + /// Sets the `tls_config` in the configuration being built. + /// + /// # Example + /// + /// ```rust + /// use rocket::config::{Config, Environment}; + /// + /// let mut config = Config::build(Environment::Staging) + /// .tls("/path/to/certs.pem", "/path/to/key.pem") + /// # ; /* + /// .unwrap(); + /// # */ + /// ``` + pub fn tls(mut self, certs_path: C, key_path: K) -> Self + where C: Into, K: Into + { + self.tls_config = Some((certs_path.into(), key_path.into())); + self + } + /// Sets the `environment` in the configuration being built. /// /// # Example @@ -260,6 +283,10 @@ impl ConfigBuilder { config.set_extras(self.extras); config.set_root(self.root); + if let Some((certs_path, key_path)) = self.tls_config { + config.set_tls(&certs_path, &key_path)?; + } + if let Some(key) = self.session_key { config.set_session_key(key)?; } diff --git a/lib/src/config/config.rs b/lib/src/config/config.rs index f6087e1c..8f6b3d51 100644 --- a/lib/src/config/config.rs +++ b/lib/src/config/config.rs @@ -5,10 +5,11 @@ use std::convert::AsRef; use std::fmt; use std::env; -use config::Environment::*; -use config::{self, Value, ConfigBuilder, Environment, ConfigError}; +#[cfg(feature = "tls")] use rustls::{Certificate, PrivateKey}; use {num_cpus, base64}; +use config::Environment::*; +use config::{Result, Table, Value, ConfigBuilder, Environment, ConfigError}; use logger::LoggingLevel; use http::Key; @@ -18,7 +19,7 @@ pub enum SessionKey { } impl SessionKey { - #[inline] + #[inline(always)] pub fn kind(&self) -> &'static str { match *self { SessionKey::Generated(_) => "generated", @@ -26,7 +27,7 @@ impl SessionKey { } } - #[inline] + #[inline(always)] fn inner(&self) -> &Key { match *self { SessionKey::Generated(ref key) | SessionKey::Provided(ref key) => key @@ -34,6 +35,15 @@ impl SessionKey { } } +#[cfg(feature = "tls")] +pub struct TlsConfig { + pub certs: Vec, + pub key: PrivateKey +} + +#[cfg(not(feature = "tls"))] +pub struct TlsConfig; + /// Structure for Rocket application configuration. /// /// A `Config` structure is typically built using the [build](#method.build) @@ -61,20 +71,73 @@ pub struct Config { pub workers: u16, /// How much information to log. pub log_level: LoggingLevel, + /// The session key. + pub(crate) session_key: SessionKey, + /// TLS configuration. + pub(crate) tls: Option, /// Extra parameters that aren't part of Rocket's core config. pub extras: HashMap, /// The path to the configuration file this config belongs to. pub config_path: PathBuf, - /// The session key. - pub(crate) session_key: SessionKey, } -macro_rules! parse { - ($conf:expr, $name:expr, $val:expr, $method:ident, $expect: expr) => ( - $val.$method().ok_or_else(|| { - $conf.bad_type($name, $val.type_str(), $expect) - }) - ); +macro_rules! config_from_raw { + ($config:expr, $name:expr, $value:expr, + $($key:ident => ($type:ident, $set:ident, $map:expr)),+ | _ => $rest:expr) => ( + match $name { + $(stringify!($key) => { + concat_idents!(value_as_, $type)($config, $name, $value) + .and_then(|parsed| $map($config.$set(parsed))) + })+ + _ => $rest + } + ) +} + +#[inline(always)] +fn value_as_str<'a>(config: &Config, name: &str, value: &'a Value) -> Result<&'a str> { + value.as_str().ok_or(config.bad_type(name, value.type_str(), "a string")) +} + +#[inline(always)] +fn value_as_u16(config: &Config, name: &str, value: &Value) -> Result { + match value.as_integer() { + Some(x) if x >= 0 && x <= (u16::max_value() as i64) => Ok(x as u16), + _ => Err(config.bad_type(name, value.type_str(), "a 16-bit unsigned integer")) + } +} + +#[inline(always)] +fn value_as_log_level(config: &Config, name: &str, value: &Value) -> Result { + value_as_str(config, name, value) + .and_then(|s| s.parse().map_err(|e| config.bad_type(name, value.type_str(), e))) +} + +#[inline(always)] +fn value_as_tls_config<'v>(config: &Config, + name: &str, + value: &'v Value, + ) -> Result<(&'v str, &'v str)> +{ + let (mut certs_path, mut key_path) = (None, None); + let table = value.as_table() + .ok_or_else(|| config.bad_type(name, value.type_str(), "a table"))?; + + let env = config.environment; + for (key, value) in table { + match key.as_str() { + "certs" => certs_path = Some(value_as_str(config, "tls.certs", value)?), + "key" => key_path = Some(value_as_str(config, "tls.key", value)?), + _ => return Err(ConfigError::UnknownKey(format!("{}.tls.{}", env, key))) + } + } + + if let (Some(certs), Some(key)) = (certs_path, key_path) { + Ok((certs, key)) + } else { + Err(config.bad_type(name, "a table with missing entries", + "a table with `certs` and `key` entries")) + } } impl Config { @@ -119,7 +182,7 @@ impl Config { /// let mut my_config = Config::new(Environment::Production).expect("cwd"); /// my_config.set_port(1001); /// ``` - pub fn new(env: Environment) -> config::Result { + pub fn new(env: Environment) -> Result { let cwd = env::current_dir().map_err(|_| ConfigError::BadCWD)?; Config::default(env, cwd.as_path().join("Rocket.custom.toml")) } @@ -131,7 +194,7 @@ impl Config { /// # Panics /// /// Panics if randomness cannot be retrieved from the OS. - pub(crate) fn default

(env: Environment, path: P) -> config::Result + pub(crate) fn default

(env: Environment, path: P) -> Result where P: AsRef { let config_path = path.as_ref().to_path_buf(); @@ -155,6 +218,7 @@ impl Config { workers: default_workers, log_level: LoggingLevel::Normal, session_key: key, + tls: None, extras: HashMap::new(), config_path: config_path, } @@ -167,6 +231,7 @@ impl Config { workers: default_workers, log_level: LoggingLevel::Normal, session_key: key, + tls: None, extras: HashMap::new(), config_path: config_path, } @@ -179,6 +244,7 @@ impl Config { workers: default_workers, log_level: LoggingLevel::Critical, session_key: key, + tls: None, extras: HashMap::new(), config_path: config_path, } @@ -209,39 +275,21 @@ impl Config { /// * **workers**: Integer (16-bit unsigned) /// * **log**: String /// * **session_key**: String (192-bit base64) - pub(crate) fn set_raw(&mut self, name: &str, val: &Value) -> config::Result<()> { - if name == "address" { - let address_str = parse!(self, name, val, as_str, "a string")?; - self.set_address(address_str)?; - } else if name == "port" { - let port = parse!(self, name, val, as_integer, "an integer")?; - if port < 0 || port > (u16::max_value() as i64) { - return Err(self.bad_type(name, val.type_str(), "a 16-bit unsigned integer")) + /// * **tls**: Table (`certs` (path as String), `key` (path as String)) + pub(crate) fn set_raw(&mut self, name: &str, val: &Value) -> Result<()> { + let (id, ok) = (|val| val, |_| Ok(())); + config_from_raw!(self, name, val, + address => (str, set_address, id), + port => (u16, set_port, ok), + workers => (u16, set_workers, ok), + session_key => (str, set_session_key, id), + log => (log_level, set_log_level, ok), + tls => (tls_config, set_raw_tls, id) + | _ => { + self.extras.insert(name.into(), val.clone()); + Ok(()) } - - self.set_port(port as u16); - } else if name == "workers" { - let workers = parse!(self, name, val, as_integer, "an integer")?; - if workers < 0 || workers > (u16::max_value() as i64) { - return Err(self.bad_type(name, val.type_str(), "a 16-bit unsigned integer")); - } - - self.set_workers(workers as u16); - } else if name == "session_key" { - let key = parse!(self, name, val, as_str, "a string")?; - self.set_session_key(key)?; - } else if name == "log" { - let level_str = parse!(self, name, val, as_str, "a string")?; - let expect = "log level ('normal', 'critical', 'debug')"; - match level_str.parse() { - Ok(level) => self.set_log_level(level), - Err(_) => return Err(self.bad_type(name, val.type_str(), expect)) - } - } else { - self.extras.insert(name.into(), val.clone()); - } - - Ok(()) + ) } /// Sets the root directory of this configuration to `root`. @@ -286,7 +334,7 @@ impl Config { /// # Ok(()) /// # } /// ``` - pub fn set_address>(&mut self, address: A) -> config::Result<()> { + pub fn set_address>(&mut self, address: A) -> Result<()> { let address = address.into(); if address.parse::().is_err() && lookup_host(&address).is_err() { return Err(self.bad_type("address", "string", "a valid hostname or IP")); @@ -310,6 +358,7 @@ impl Config { /// # Ok(()) /// # } /// ``` + #[inline] pub fn set_port(&mut self, port: u16) { self.port = port; } @@ -328,6 +377,7 @@ impl Config { /// # Ok(()) /// # } /// ``` + #[inline] pub fn set_workers(&mut self, workers: u16) { self.workers = workers; } @@ -354,7 +404,7 @@ impl Config { /// # Ok(()) /// # } /// ``` - pub fn set_session_key>(&mut self, key: K) -> config::Result<()> { + pub fn set_session_key>(&mut self, key: K) -> Result<()> { let key = key.into(); let error = self.bad_type("session_key", "string", "a 256-bit base64 encoded string"); @@ -387,10 +437,42 @@ impl Config { /// # Ok(()) /// # } /// ``` + #[inline] pub fn set_log_level(&mut self, log_level: LoggingLevel) { self.log_level = log_level; } + #[cfg(feature = "tls")] + pub fn set_tls(&mut self, certs_path: &str, key_path: &str) -> Result<()> { + use hyper_rustls::util as tls; + + let err = "nonexistent or invalid file"; + let certs = tls::load_certs(certs_path) + .map_err(|_| self.bad_type("tls", err, "a readable certificates file"))?; + let key = tls::load_private_key(key_path) + .map_err(|_| self.bad_type("tls", err, "a readable private key file"))?; + + self.tls = Some(TlsConfig { certs, key }); + Ok(()) + } + + #[cfg(not(feature = "tls"))] + pub fn set_tls(&mut self, _: &str, _: &str) -> Result<()> { + self.tls = Some(TlsConfig); + Ok(()) + } + + #[cfg(not(test))] + #[inline(always)] + fn set_raw_tls(&mut self, paths: (&str, &str)) -> Result<()> { + self.set_tls(paths.0, paths.1) + } + + #[cfg(test)] + fn set_raw_tls(&mut self, _: (&str, &str)) -> Result<()> { + Ok(()) + } + /// Sets the extras for `self` to be the key/value pairs in `extras`. /// encoded string. /// @@ -413,6 +495,7 @@ impl Config { /// # Ok(()) /// # } /// ``` + #[inline] pub fn set_extras(&mut self, extras: HashMap) { self.extras = extras; } @@ -441,6 +524,7 @@ impl Config { /// # Ok(()) /// # } /// ``` + #[inline] pub fn extras<'a>(&'a self) -> impl Iterator { self.extras.iter().map(|(k, v)| (k.as_str(), v)) } @@ -470,9 +554,9 @@ impl Config { /// /// assert_eq!(config.get_str("my_extra"), Ok("extra_value")); /// ``` - pub fn get_str<'a>(&'a self, name: &str) -> config::Result<&'a str> { - let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; - parse!(self, name, value, as_str, "a string") + pub fn get_str<'a>(&'a self, name: &str) -> Result<&'a str> { + let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; + val.as_str().ok_or_else(|| self.bad_type(name, val.type_str(), "a string")) } /// Attempts to retrieve the extra named `name` as an integer. @@ -494,9 +578,9 @@ impl Config { /// /// assert_eq!(config.get_int("my_extra"), Ok(1025)); /// ``` - pub fn get_int(&self, name: &str) -> config::Result { - let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; - parse!(self, name, value, as_integer, "an integer") + pub fn get_int(&self, name: &str) -> Result { + let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; + val.as_integer().ok_or_else(|| self.bad_type(name, val.type_str(), "an integer")) } /// Attempts to retrieve the extra named `name` as a boolean. @@ -518,9 +602,9 @@ impl Config { /// /// assert_eq!(config.get_bool("my_extra"), Ok(true)); /// ``` - pub fn get_bool(&self, name: &str) -> config::Result { - let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; - parse!(self, name, value, as_bool, "a boolean") + pub fn get_bool(&self, name: &str) -> Result { + let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; + val.as_bool().ok_or_else(|| self.bad_type(name, val.type_str(), "a boolean")) } /// Attempts to retrieve the extra named `name` as a float. @@ -542,9 +626,9 @@ impl Config { /// /// assert_eq!(config.get_float("pi"), Ok(3.14159)); /// ``` - pub fn get_float(&self, name: &str) -> config::Result { - let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; - parse!(self, name, value, as_float, "a float") + pub fn get_float(&self, name: &str) -> Result { + let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; + val.as_float().ok_or_else(|| self.bad_type(name, val.type_str(), "a float")) } /// Attempts to retrieve the extra named `name` as a slice of an array. @@ -566,9 +650,9 @@ impl Config { /// /// assert!(config.get_slice("numbers").is_ok()); /// ``` - pub fn get_slice(&self, name: &str) -> config::Result<&[Value]> { - let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; - parse!(self, name, value, as_slice, "a slice") + pub fn get_slice(&self, name: &str) -> Result<&[Value]> { + let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; + val.as_slice().ok_or_else(|| self.bad_type(name, val.type_str(), "a slice")) } /// Attempts to retrieve the extra named `name` as a table. @@ -594,9 +678,9 @@ impl Config { /// /// assert!(config.get_table("my_table").is_ok()); /// ``` - pub fn get_table(&self, name: &str) -> config::Result<&config::Table> { - let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; - parse!(self, name, value, as_table, "a table") + pub fn get_table(&self, name: &str) -> Result<&Table> { + let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; + val.as_table().ok_or_else(|| self.bad_type(name, val.type_str(), "a table")) } /// Returns the path at which the configuration file for `self` is stored. diff --git a/lib/src/config/error.rs b/lib/src/config/error.rs index 8a814de2..01179d2f 100644 --- a/lib/src/config/error.rs +++ b/lib/src/config/error.rs @@ -52,8 +52,12 @@ pub enum ConfigError { ParseError(String, PathBuf, Vec), /// There was a TOML parsing error in a config environment variable. /// - /// Parameters: (env_key, env_value, expected type) - BadEnvVal(String, String, &'static str), + /// Parameters: (env_key, env_value, error) + BadEnvVal(String, String, String), + /// The entry (key) is unknown. + /// + /// Parameters: (key) + UnknownKey(String), } impl ConfigError { @@ -95,11 +99,14 @@ impl ConfigError { trace_!("'{}' - {}", error_source, White.paint(&error.desc)); } } - BadEnvVal(ref key, ref value, ref expected) => { + BadEnvVal(ref key, ref value, ref error) => { error!("environment variable '{}={}' could not be parsed", White.paint(key), White.paint(value)); - info_!("value for {:?} must be {}", - White.paint(key), White.paint(expected)) + info_!("{}", White.paint(error)); + } + UnknownKey(ref key) => { + error!("the configuration key '{}' is unknown and disallowed in \ + this position", White.paint(key)); } } } @@ -123,6 +130,7 @@ impl fmt::Display for ConfigError { BadFilePath(ref p, _) => write!(f, "{:?} is not a valid config path", p), BadEnv(ref e) => write!(f, "{:?} is not a valid `ROCKET_ENV` value", e), ParseError(..) => write!(f, "the config file contains invalid TOML"), + UnknownKey(ref k) => write!(f, "'{}' is an unknown key", k), BadEntry(ref e, _) => { write!(f, "{:?} is not a valid `[environment]` entry", e) } @@ -148,6 +156,7 @@ impl Error for ConfigError { ParseError(..) => "the config file contains invalid TOML", BadType(..) => "a key was specified with a value of the wrong type", BadEnvVal(..) => "an environment variable could not be parsed", + UnknownKey(..) => "an unknown key was used in a disallowed position", } } } diff --git a/lib/src/config/mod.rs b/lib/src/config/mod.rs index 401eeec7..488d03ba 100644 --- a/lib/src/config/mod.rs +++ b/lib/src/config/mod.rs @@ -43,6 +43,10 @@ //! * **session_key**: _[string]_ a 256-bit base64 encoded string (44 //! characters) to use as the session key //! * example: `"8Xui8SN4mI+7egV/9dlfYYLGQJeEx4+DwmSQLwDVXJg="` +//! * **tls**: _[dict]_ a dictionary with two keys: 1) `certs`: _[string]_ a +//! path to a certificate chain in PEM format, and 2) `key`: _[string]_ a +//! path to a private key file in PEM format for the certificate in `certs` +//! * example: `{ certs = "/path/to/certs.pem", key = "/path/to/key.pem" }` //! //! ### Rocket.toml //! @@ -118,6 +122,31 @@ //! //! Environment variables take precedence over all other configuration methods: //! if the variable is set, it will be used as the value for the parameter. +//! Variable values are parsed as if they were TOML syntax. As illustration, +//! consider the following examples: +//! +//! ```sh +//! ROCKET_INTEGER=1 +//! ROCKET_FLOAT=3.14 +//! ROCKET_STRING=Hello +//! ROCKET_STRING="Hello" +//! ROCKET_BOOL=true +//! ROCKET_ARRAY=[1,"b",3.14] +//! ROCKET_DICT={key="abc",val=123} +//! ``` +//! +//! ### TLS Configuration +//! +//! TLS can be enabled by specifying the `tls.key` and `tls.certs` parameters. +//! Rocket must be compiled with the `tls` feature enabled for the parameters to +//! take effect. The recommended way to specify the parameters is via the +//! `global` environment: +//! +//! ```toml +//! [global.tls] +//! certs = "/path/to/certs.pem" +//! key = "/path/to/key.pem" +//! ``` //! //! ## Retrieving Configuration Parameters //! @@ -278,6 +307,7 @@ impl RocketConfig { Err(ConfigError::NotFound) } + #[inline] fn get_mut(&mut self, env: Environment) -> &mut Config { match self.config.get_mut(&env) { Some(config) => config, @@ -306,33 +336,37 @@ impl RocketConfig { } /// Retrieves the `Config` for the active environment. + #[inline] pub fn active(&self) -> &Config { self.get(self.active_env) } // Override all environments with values from env variables if present. fn override_from_env(&mut self) -> Result<()> { - 'outer: for (env_key, env_val) in env::vars() { - if env_key.len() < ENV_VAR_PREFIX.len() { + for (key, val) in env::vars() { + if key.len() < ENV_VAR_PREFIX.len() { continue - } else if !uncased_eq(&env_key[..ENV_VAR_PREFIX.len()], ENV_VAR_PREFIX) { + } else if !uncased_eq(&key[..ENV_VAR_PREFIX.len()], ENV_VAR_PREFIX) { continue } // Skip environment variables that are handled elsewhere. - for prehandled_var in PREHANDLED_VARS.iter() { - if uncased_eq(&env_key, &prehandled_var) { - continue 'outer - } + if PREHANDLED_VARS.iter().any(|var| uncased_eq(&key, var)) { + continue } // Parse the key and value and try to set the variable for all envs. - let key = env_key[ENV_VAR_PREFIX.len()..].to_lowercase(); - let val = parse_simple_toml_value(&env_val); + let key = key[ENV_VAR_PREFIX.len()..].to_lowercase(); + let toml_val = match parse_simple_toml_value(&val) { + Ok(val) => val, + Err(e) => return Err(ConfigError::BadEnvVal(key, val, e.into())) + }; + for env in &Environment::all() { - match self.get_mut(*env).set_raw(&key, &val) { - Err(ConfigError::BadType(_, exp, _, _)) => { - return Err(ConfigError::BadEnvVal(env_key, env_val, exp)) + match self.get_mut(*env).set_raw(&key, &toml_val) { + Err(ConfigError::BadType(_, exp, actual, _)) => { + let e = format!("expected {}, but found {}", exp, actual); + return Err(ConfigError::BadEnvVal(key, val, e)) } Err(e) => return Err(e), Ok(_) => { /* move along */ } @@ -458,7 +492,7 @@ unsafe fn private_init() { let config = RocketConfig::read().unwrap_or_else(|e| { match e { ParseError(..) | BadEntry(..) | BadEnv(..) | BadType(..) - | BadFilePath(..) | BadEnvVal(..) => bail(e), + | BadFilePath(..) | BadEnvVal(..) | UnknownKey(..) => bail(e), IOError | BadCWD => warn!("Failed reading Rocket.toml. Using defaults."), NotFound => { /* try using the default below */ } } @@ -697,6 +731,64 @@ mod test { "#.to_string(), TEST_CONFIG_FILENAME).is_err()); } + // Only do this test when the tls feature is disabled since the file paths + // we're supplying don't actually exist. + #[test] + fn test_good_tls_values() { + // Take the lock so changing the environment doesn't cause races. + let _env_lock = ENV_LOCK.lock().unwrap(); + env::set_var(CONFIG_ENV, "dev"); + + assert!(RocketConfig::parse(r#" + [staging] + tls = { certs = "some/path.pem", key = "some/key.pem" } + "#.to_string(), TEST_CONFIG_FILENAME).is_ok()); + + assert!(RocketConfig::parse(r#" + [staging.tls] + certs = "some/path.pem" + key = "some/key.pem" + "#.to_string(), TEST_CONFIG_FILENAME).is_ok()); + + assert!(RocketConfig::parse(r#" + [global.tls] + certs = "some/path.pem" + key = "some/key.pem" + "#.to_string(), TEST_CONFIG_FILENAME).is_ok()); + + assert!(RocketConfig::parse(r#" + [global] + tls = { certs = "some/path.pem", key = "some/key.pem" } + "#.to_string(), TEST_CONFIG_FILENAME).is_ok()); + } + + #[test] + fn test_bad_tls_config() { + // Take the lock so changing the environment doesn't cause races. + let _env_lock = ENV_LOCK.lock().unwrap(); + env::remove_var(CONFIG_ENV); + + assert!(RocketConfig::parse(r#" + [development] + tls = "hello" + "#.to_string(), TEST_CONFIG_FILENAME).is_err()); + + assert!(RocketConfig::parse(r#" + [development] + tls = { certs = "some/path.pem" } + "#.to_string(), TEST_CONFIG_FILENAME).is_err()); + + assert!(RocketConfig::parse(r#" + [development] + tls = { certs = "some/path.pem", key = "some/key.pem", extra = "bah" } + "#.to_string(), TEST_CONFIG_FILENAME).is_err()); + + assert!(RocketConfig::parse(r#" + [staging] + tls = { cert = "some/path.pem", key = "some/key.pem" } + "#.to_string(), TEST_CONFIG_FILENAME).is_err()); + } + #[test] fn test_good_port_values() { // Take the lock so changing the environment doesn't cause races. diff --git a/lib/src/config/toml_ext.rs b/lib/src/config/toml_ext.rs index cc2a107a..ac2b7e21 100644 --- a/lib/src/config/toml_ext.rs +++ b/lib/src/config/toml_ext.rs @@ -4,20 +4,63 @@ use std::str::FromStr; use config::Value; -pub fn parse_simple_toml_value(string: &str) -> Value { - if let Ok(int) = i64::from_str(string) { - return Value::Integer(int) +pub fn parse_simple_toml_value(string: &str) -> Result { + if string.is_empty() { + return Err("value is empty") } - if let Ok(boolean) = bool::from_str(string) { - return Value::Boolean(boolean) - } + let value = if let Ok(int) = i64::from_str(string) { + Value::Integer(int) + } else if let Ok(float) = f64::from_str(string) { + Value::Float(float) + } else if let Ok(boolean) = bool::from_str(string) { + Value::Boolean(boolean) + } else if string.starts_with('{') { + if !string.ends_with('}') { + return Err("value is missing closing '}'") + } - if let Ok(float) = f64::from_str(string) { - return Value::Float(float) - } + let mut table = BTreeMap::new(); + let inner = &string[1..string.len() - 1].trim(); + if !inner.is_empty() { + for key_val in inner.split(',') { + let (key, val) = match key_val.find('=') { + Some(i) => (&key_val[..i], &key_val[(i + 1)..]), + None => return Err("missing '=' in dicitonary key/value pair") + }; - Value::String(string.to_string()) + let key = key.trim().to_string(); + let val = parse_simple_toml_value(val.trim())?; + table.insert(key, val); + } + } + + Value::Table(table) + } else if string.starts_with('[') { + if !string.ends_with(']') { + return Err("value is missing closing ']'") + } + + let mut vals = vec![]; + let inner = &string[1..string.len() - 1].trim(); + if !inner.is_empty() { + for val_str in inner.split(',') { + vals.push(parse_simple_toml_value(val_str.trim())?); + } + } + + Value::Array(vals) + } else if string.starts_with('"') { + if !string[1..].ends_with('"') { + return Err("value is missing closing '\"'"); + } + + Value::String(string[1..string.len() - 1].to_string()) + } else { + Value::String(string.to_string()) + }; + + Ok(value) } /// Conversion trait from standard types into TOML `Value`s. @@ -27,18 +70,21 @@ pub trait IntoValue { } impl<'a> IntoValue for &'a str { + #[inline(always)] fn into_value(self) -> Value { Value::String(self.to_string()) } } impl IntoValue for Value { + #[inline(always)] fn into_value(self) -> Value { self } } impl IntoValue for Vec { + #[inline(always)] fn into_value(self) -> Value { Value::Array(self.into_iter().map(|v| v.into_value()).collect()) } @@ -87,3 +133,50 @@ impl_into_value!(Boolean: bool); impl_into_value!(Float: f64); impl_into_value!(Float: f32, as f64); +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + use super::parse_simple_toml_value; + use super::IntoValue; + use super::Value::*; + + macro_rules! assert_parse { + ($string:expr, $value:expr) => ( + match parse_simple_toml_value($string) { + Ok(value) => assert_eq!(value, $value), + Err(e) => panic!("{:?} failed to parse: {:?}", $string, e) + }; + ) + } + + #[test] + fn parse_toml_values() { + assert_parse!("1", Integer(1)); + assert_parse!("1.32", Float(1.32)); + assert_parse!("true", Boolean(true)); + assert_parse!("false", Boolean(false)); + assert_parse!("hello, WORLD!", String("hello, WORLD!".into())); + assert_parse!("hi", String("hi".into())); + assert_parse!("\"hi\"", String("hi".into())); + + assert_parse!("[]", Array(Vec::new())); + assert_parse!("[1]", vec![1].into_value()); + assert_parse!("[1, 2, 3]", vec![1, 2, 3].into_value()); + assert_parse!("[1.32, 2]", + vec![1.32.into_value(), 2.into_value()].into_value()); + + assert_parse!("{}", Table(BTreeMap::new())); + assert_parse!("{a=b}", Table({ + let mut map = BTreeMap::new(); + map.insert("a".into(), "b".into_value()); + map + })); + assert_parse!("{v=1, on=true,pi=3.14}", Table({ + let mut map = BTreeMap::new(); + map.insert("v".into(), 1.into_value()); + map.insert("on".into(), true.into_value()); + map.insert("pi".into(), 3.14.into_value()); + map + })); + } +} diff --git a/lib/src/data/data.rs b/lib/src/data/data.rs index 0000336e..d6e92e1d 100644 --- a/lib/src/data/data.rs +++ b/lib/src/data/data.rs @@ -4,9 +4,10 @@ use std::fs::File; use std::time::Duration; use std::mem::transmute; -use super::data_stream::{DataStream, StreamReader, kill_stream}; +#[cfg(feature = "tls")] use hyper_rustls::WrappedStream; use ext::ReadExt; +use super::data_stream::{DataStream, HyperNetStream, StreamReader, kill_stream}; use http::hyper::h1::HttpReader; use http::hyper::buffer; @@ -82,32 +83,47 @@ impl Data { DataStream::new(stream, network) } + // FIXME: This is absolutely terrible (downcasting!), thanks to Hyper. pub(crate) fn from_hyp(mut h_body: BodyReader) -> Result { - // FIXME: This is asolutely terrible, thanks to Hyper. + // Create the Data object from hyper's buffer. + let (vec, pos, cap) = h_body.get_mut().take_buf(); + let net_stream = h_body.get_ref().get_ref(); + + #[cfg(feature = "tls")] + fn concrete_stream(stream: &&mut NetworkStream) -> Option { + stream.downcast_ref::() + .map(|s| HyperNetStream::Http(s.clone())) + .or_else(|| { + stream.downcast_ref::() + .map(|s| HyperNetStream::Https(s.clone())) + }) + } + + #[cfg(not(feature = "tls"))] + fn concrete_stream(stream: &&mut NetworkStream) -> Option { + stream.downcast_ref::() + .map(|s| HyperNetStream::Http(s.clone())) + } // Retrieve the underlying HTTPStream from Hyper. - let mut stream = match h_body.get_ref().get_ref() - .downcast_ref::() { - Some(s) => { - let owned_stream = s.clone(); - let buf_len = h_body.get_ref().get_buf().len() as u64; - match h_body { - SizedReader(_, n) => SizedReader(owned_stream, n - buf_len), - EofReader(_) => EofReader(owned_stream), - EmptyReader(_) => EmptyReader(owned_stream), - ChunkedReader(_, n) => - ChunkedReader(owned_stream, n.map(|k| k - buf_len)), - } - }, - None => return Err("Stream is not an HTTP stream!"), + let stream = match concrete_stream(net_stream) { + Some(stream) => stream, + None => return Err("Stream is not an HTTP(s) stream!") }; // Set the read timeout to 5 seconds. - stream.get_mut().set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + stream.set_read_timeout(Some(Duration::from_secs(5))).expect("timeout set"); - // Create the Data object from hyper's buffer. - let (vec, pos, cap) = h_body.get_mut().take_buf(); - Ok(Data::new(vec, pos, cap, stream)) + // Create a reader from the stream. Don't read what's already buffered. + let buffered = (cap - pos) as u64; + let reader = match h_body { + SizedReader(_, n) => SizedReader(stream, n - buffered), + EofReader(_) => EofReader(stream), + EmptyReader(_) => EmptyReader(stream), + ChunkedReader(_, n) => ChunkedReader(stream, n.map(|k| k - buffered)), + }; + + Ok(Data::new(vec, pos, cap, reader)) } /// Retrieve the `peek` buffer. @@ -151,10 +167,10 @@ impl Data { // in the buffer is at `pos` and the buffer has `cap` valid bytes. The // remainder of the data bytes can be read from `stream`. pub(crate) fn new(mut buf: Vec, - pos: usize, - mut cap: usize, - mut stream: StreamReader) - -> Data { + pos: usize, + mut cap: usize, + mut stream: StreamReader + ) -> Data { // Make sure the buffer is large enough for the bytes we want to peek. const PEEK_BYTES: usize = 4096; if buf.len() < PEEK_BYTES { @@ -198,4 +214,3 @@ impl Drop for Data { } } } - diff --git a/lib/src/data/data_stream.rs b/lib/src/data/data_stream.rs index 1eeb8838..a63371d4 100644 --- a/lib/src/data/data_stream.rs +++ b/lib/src/data/data_stream.rs @@ -1,12 +1,80 @@ use std::io::{self, BufRead, Read, Cursor, BufReader, Chain, Take}; -use std::net::Shutdown; +use std::net::{SocketAddr, Shutdown}; +use std::time::Duration; + +#[cfg(feature = "tls")] use hyper_rustls::WrappedStream as RustlsStream; use http::hyper::net::{HttpStream, NetworkStream}; use http::hyper::h1::HttpReader; -pub type StreamReader = HttpReader; +pub type StreamReader = HttpReader; pub type InnerStream = Chain>>, BufReader>; +#[derive(Clone)] +pub enum HyperNetStream { + Http(HttpStream), + #[cfg(feature = "tls")] + Https(RustlsStream) +} + +macro_rules! with_inner { + ($net:expr, |$stream:ident| $body:expr) => ({ + trace!("{}:{}", file!(), line!()); + match *$net { + HyperNetStream::Http(ref $stream) => $body, + #[cfg(feature = "tls")] HyperNetStream::Https(ref $stream) => $body + } + }); + ($net:expr, |mut $stream:ident| $body:expr) => ({ + trace!("{}:{}", file!(), line!()); + match *$net { + HyperNetStream::Http(ref mut $stream) => $body, + #[cfg(feature = "tls")] HyperNetStream::Https(ref mut $stream) => $body + } + }) +} + +impl io::Read for HyperNetStream { + #[inline(always)] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + with_inner!(self, |mut stream| io::Read::read(stream, buf)) + } +} + +impl io::Write for HyperNetStream { + #[inline(always)] + fn write(&mut self, buf: &[u8]) -> io::Result { + with_inner!(self, |mut stream| io::Write::write(stream, buf)) + } + + #[inline(always)] + fn flush(&mut self) -> io::Result<()> { + with_inner!(self, |mut stream| io::Write::flush(stream)) + } +} + +impl NetworkStream for HyperNetStream { + #[inline(always)] + fn peer_addr(&mut self) -> io::Result { + with_inner!(self, |mut stream| NetworkStream::peer_addr(stream)) + } + + #[inline(always)] + fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + with_inner!(self, |stream| NetworkStream::set_read_timeout(stream, dur)) + } + + #[inline(always)] + fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + with_inner!(self, |stream| NetworkStream::set_write_timeout(stream, dur)) + } + + #[inline(always)] + fn close(&mut self, how: Shutdown) -> io::Result<()> { + with_inner!(self, |mut stream| NetworkStream::close(stream, how)) + } +} + /// Raw data stream of a request body. /// /// This stream can only be obtained by calling @@ -15,33 +83,38 @@ pub type InnerStream = Chain>>, BufReader>; /// Instead, it must be used as an opaque `Read` or `BufRead` structure. pub struct DataStream { stream: InnerStream, - network: HttpStream, + network: HyperNetStream, } impl DataStream { - pub(crate) fn new(stream: InnerStream, network: HttpStream) -> DataStream { - DataStream { stream: stream, network: network, } + #[inline(always)] + pub(crate) fn new(stream: InnerStream, network: HyperNetStream) -> DataStream { + DataStream { stream, network } } } impl Read for DataStream { + #[inline(always)] fn read(&mut self, buf: &mut [u8]) -> io::Result { self.stream.read(buf) } } impl BufRead for DataStream { + #[inline(always)] fn fill_buf(&mut self) -> io::Result<&[u8]> { self.stream.fill_buf() } + #[inline(always)] fn consume(&mut self, amt: usize) { self.stream.consume(amt) } } +// pub fn kill_stream(stream: &mut S, network: &mut HyperNetStream) { pub fn kill_stream(stream: &mut S, network: &mut N) { - io::copy(&mut stream.take(1024), &mut io::sink()).expect("sink"); + io::copy(&mut stream.take(1024), &mut io::sink()).expect("kill_stream: sink"); // If there are any more bytes, kill it. let mut buf = [0]; @@ -61,4 +134,3 @@ impl Drop for DataStream { kill_stream(&mut self.stream, &mut self.network); } } - diff --git a/lib/src/error.rs b/lib/src/error.rs index a9ce4c02..ebbf14c5 100644 --- a/lib/src/error.rs +++ b/lib/src/error.rs @@ -111,6 +111,7 @@ impl LaunchError { } impl From for LaunchError { + #[inline] fn from(error: hyper::Error) -> LaunchError { match error { hyper::Error::Io(e) => LaunchError::new(LaunchErrorKind::Io(e)), @@ -120,6 +121,7 @@ impl From for LaunchError { } impl fmt::Display for LaunchErrorKind { + #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { LaunchErrorKind::Io(ref e) => write!(f, "I/O error: {}", e), @@ -129,6 +131,7 @@ impl fmt::Display for LaunchErrorKind { } impl fmt::Debug for LaunchError { + #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.mark_handled(); write!(f, "{:?}", self.kind()) @@ -136,6 +139,7 @@ impl fmt::Debug for LaunchError { } impl fmt::Display for LaunchError { + #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.mark_handled(); write!(f, "{}", self.kind()) @@ -143,6 +147,7 @@ impl fmt::Display for LaunchError { } impl ::std::error::Error for LaunchError { + #[inline] fn description(&self) -> &str { self.mark_handled(); match *self.kind() { diff --git a/lib/src/lib.rs b/lib/src/lib.rs index f69710f8..b2842639 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -7,6 +7,7 @@ #![feature(lookup_host)] #![feature(plugin)] #![feature(never_type)] +#![feature(concat_idents)] #![plugin(pear_codegen)] @@ -99,6 +100,8 @@ #[macro_use] extern crate log; #[macro_use] extern crate pear; +#[cfg(feature = "tls")] extern crate rustls; +#[cfg(feature = "tls")] extern crate hyper_rustls; extern crate term_painter; extern crate hyper; extern crate url; diff --git a/lib/src/logger.rs b/lib/src/logger.rs index 39ec56f7..2550dd7c 100644 --- a/lib/src/logger.rs +++ b/lib/src/logger.rs @@ -32,13 +32,14 @@ impl LoggingLevel { } impl FromStr for LoggingLevel { - type Err = (); + type Err = &'static str; + fn from_str(s: &str) -> Result { let level = match s { "critical" => LoggingLevel::Critical, "normal" => LoggingLevel::Normal, "debug" => LoggingLevel::Debug, - _ => return Err(()) + _ => return Err("a log level (debug, normal, critical)") }; Ok(level) @@ -87,9 +88,10 @@ impl Log for RocketLogger { return; } - // Don't print Hyper's messages unless Debug is enabled. + // Don't print Hyper or Rustls messages unless debug is enabled. let from_hyper = record.location().module_path().starts_with("hyper::"); - if from_hyper && self.0 != LoggingLevel::Debug { + let from_rustls = record.location().module_path().starts_with("rustls::"); + if self.0 != LoggingLevel::Debug && (from_hyper || from_rustls) { return; } diff --git a/lib/src/rocket.rs b/lib/src/rocket.rs index 22734ea6..0dae48b9 100644 --- a/lib/src/rocket.rs +++ b/lib/src/rocket.rs @@ -6,9 +6,9 @@ use std::io::{self, Write}; use term_painter::Color::*; use term_painter::ToStyle; - use state::Container; +#[cfg(feature = "tls")] use hyper_rustls::TlsServer; use {logger, handler}; use ext::ReadExt; use config::{self, Config}; @@ -74,6 +74,35 @@ impl hyper::Handler for Rocket { } } +// This macro is a terrible hack to get around Hyper's Server type. What we +// want is to use almost exactly the same launch code when we're serving over +// HTTPS as over HTTP. But Hyper forces two different types, so we can't use the +// same code, at least not trivially. These macros get around that by passing in +// the same code as a continuation in `$continue`. This wouldn't work as a +// regular function taking in a closure because the types of the inputs to the +// closure would be different depending on whether TLS was enabled or not. +#[cfg(not(feature = "tls"))] +macro_rules! serve { + ($rocket:expr, $addr:expr, |$server:ident, $proto:ident| $continue:expr) => ({ + let ($proto, $server) = ("http://", hyper::Server::http($addr)); + $continue + }) +} + +#[cfg(feature = "tls")] +macro_rules! serve { + ($rocket:expr, $addr:expr, |$server:ident, $proto:ident| $continue:expr) => ({ + if let Some(ref tls) = $rocket.config.tls { + let tls = TlsServer::new(tls.certs.clone(), tls.key.clone()); + let ($proto, $server) = ("https://", hyper::Server::https($addr, tls)); + $continue + } else { + let ($proto, $server) = ("http://", hyper::Server::http($addr)); + $continue + } + }) +} + impl Rocket { #[inline] fn issue_response(&self, mut response: Response, hyp_res: hyper::FreshResponse) { @@ -262,7 +291,11 @@ impl Rocket { Outcome::Forward(data) } - // TODO: DOC. + // Finds the error catcher for the status `status` and executes it fo the + // given request `req`. If a user has registere a catcher for `status`, the + // catcher is called. If the catcher fails to return a good response, the + // 500 catcher is executed. if there is no registered catcher for `status`, + // the default catcher is used. fn handle_error<'r>(&self, status: Status, req: &'r Request) -> Response<'r> { warn_!("Responding with {} catcher.", Red.paint(&status)); @@ -350,6 +383,16 @@ impl Rocket { info_!("workers: {}", White.paint(config.workers)); info_!("session key: {}", White.paint(config.session_key.kind())); + let tls_configured = config.tls.is_some(); + if tls_configured && cfg!(feature = "tls") { + info_!("tls: {}", White.paint("enabled")); + } else { + info_!("tls: {}", White.paint("disabled")); + if tls_configured { + warn_!("tls is configured, but the tls feature is disabled"); + } + } + for (name, value) in config.extras() { info_!("{} {}: {}", Yellow.paint("[extra]"), name, White.paint(value)); } @@ -553,22 +596,24 @@ impl Rocket { } let full_addr = format!("{}:{}", self.config.address, self.config.port); - let server = match hyper::Server::http(full_addr.as_str()) { - Ok(hyper_server) => hyper_server, - Err(e) => return LaunchError::from(e) - }; + serve!(self, &full_addr, |server, proto| { + let server = match server { + Ok(server) => server, + Err(e) => return LaunchError::from(e) + }; - info!("🚀 {} {}{}", - White.paint("Rocket has launched from"), - White.bold().paint("http://"), - White.bold().paint(&full_addr)); + info!("🚀 {} {}{}", + White.paint("Rocket has launched from"), + White.bold().paint(proto), + White.bold().paint(&full_addr)); - let threads = self.config.workers as usize; - if let Err(e) = server.handle_threads(self, threads) { - return LaunchError::from(e); - } + let threads = self.config.workers as usize; + if let Err(e) = server.handle_threads(self, threads) { + return LaunchError::from(e); + } - unreachable!("the call to `handle_threads` should block on success") + unreachable!("the call to `handle_threads` should block on success") + }) } /// Retrieves all of the mounted routes.