Initial TLS support.

This commit introduces TLS support, provided by `rustls` and a fork of
`hyper-rustls`. TLS support is enabled via the `tls` feature and
activated when the `tls` configuration parameter is set. A new
`hello_tls` example illustrates its usage.

This commit also introduces more robust and complete configuration
settings via environment variables. In particular, quoted string,
array, and table (dictionaries) based configuration parameters can now
be set via environment variables.

Resolves #28.
This commit is contained in:
Sergio Benitez 2017-04-13 00:18:31 -07:00
parent cf47daa8e1
commit 1516ca4fb6
19 changed files with 739 additions and 145 deletions

View File

@ -33,4 +33,5 @@ members = [
"examples/uuid", "examples/uuid",
"examples/session", "examples/session",
"examples/raw_sqlite", "examples/raw_sqlite",
"examples/hello_tls",
] ]

View File

@ -0,0 +1,11 @@
[package]
name = "hello_tls"
version = "0.0.0"
workspace = "../../"
[dependencies]
rocket = { path = "../../lib", features = ["tls"] }
rocket_codegen = { path = "../../codegen" }
[dev-dependencies]
rocket = { path = "../../lib", features = ["testing"] }

View File

@ -0,0 +1,11 @@
# The certificate/private key pair used here was generated via openssl:
#
# openssl req -x509 -newkey rsa:4096 -nodes -sha256 -days 3650 \
# -keyout key.pem -out cert.pem
#
# The certificate is self-signed. As such, you will need to trust it directly
# for your browser to refer to the connection as secure. You should NEVER use
# this certificate/key pair. It is here for DEMONSTRATION PURPOSES ONLY.
[global.tls]
certs = "private/cert.pem"
key = "private/key.pem"

View File

@ -0,0 +1,37 @@
-----BEGIN CERTIFICATE-----
MIIGXjCCBEagAwIBAgIJAJBTO2YLMz4tMA0GCSqGSIb3DQEBCwUAMHwxCzAJBgNV
BAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMREwDwYDVQQHEwhTdGFuZm9yZDEP
MA0GA1UEChMGUm9ja2V0MRcwFQYDVQQDEw5TZXJnaW8gQmVuaXRlejEbMBkGCSqG
SIb3DQEJARYMc2JAcm9ja2V0LnJzMB4XDTE3MDQwOTA0MTIxM1oXDTI3MDQwNzA0
MTIxM1owfDELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExETAPBgNV
BAcTCFN0YW5mb3JkMQ8wDQYDVQQKEwZSb2NrZXQxFzAVBgNVBAMTDlNlcmdpbyBC
ZW5pdGV6MRswGQYJKoZIhvcNAQkBFgxzYkByb2NrZXQucnMwggIiMA0GCSqGSIb3
DQEBAQUAA4ICDwAwggIKAoICAQCzdm0ZxLNP4TlJBI2IpeVT4S6hZeBkem/aj4NZ
mhHA06HXVqcUw3W03YQklhO7E305uU/BTRz5q0BIa2DCPyZDUCkwTjOZAuFiiZzc
AZz/zhu2RwLWeYttlvjKewrIe0k9zrPaPXpdcFe0xq2mcUon0fyRztL1H8EYEScb
/TJqM1LkWKGSJEOMDeEYMVnJn/x9yFgfC82u/4GBc3q3Si2uRLCMkTLsg6TC27EF
kCVuOISf1+CvAKgk2x29SGm/nYoTe+j6YLm12h41S6JlGO9zJnORlwb4Mz5h+72p
NBaVER72kNxwskTNg2IWur7NM2Xi/nAfZ7+YOopgwosRuZl8Nw6CcpWDkGdLnO7X
H18Wy/BXOamXVa65tWefwlCiJ8bkqZgik8AHX36KZzTzkDO5g/4JAQDh4G56paGu
hcd1LXkGvTDuaSN4BkHDuYucr89aliWV/AKzum4BJkyKk3lVWDb9nfwyTRegsZg5
ipTW7xLhvxzjeoLuDHybRzsw+2NFQoHA4PUouzC1n2/+eJIVysa6p5UZXTcTNGVd
rdU3GmifpFDBv4NwQrQ1y2izw0b+dbZ7DBAQIqW3toHeBUmmTiHSmQR5QT3Dz9HA
l2npMu4S2ZKQYJj+zqxyETzrOgz76LW1yZ3uAbX7z0OxlOoC67XYGAAWlDyU4pZc
qcnR1QIDAQABo4HiMIHfMB0GA1UdDgQWBBSUdwqW1sNXbeS29wXaJL0P9glPmTCB
rwYDVR0jBIGnMIGkgBSUdwqW1sNXbeS29wXaJL0P9glPmaGBgKR+MHwxCzAJBgNV
BAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMREwDwYDVQQHEwhTdGFuZm9yZDEP
MA0GA1UEChMGUm9ja2V0MRcwFQYDVQQDEw5TZXJnaW8gQmVuaXRlejEbMBkGCSqG
SIb3DQEJARYMc2JAcm9ja2V0LnJzggkAkFM7ZgszPi0wDAYDVR0TBAUwAwEB/zAN
BgkqhkiG9w0BAQsFAAOCAgEAXfsz3a1iDWoyyobVlFV3NvVD5CeZ9oWh/IvmbgfB
XMEB3ZLy3Cqn4op1u6Kbo/L8E+YYMlIGdqUtAlpJrtF1k3KeVGBScx6YebCPWcuL
aRO/l1qUR70RhA/Yrz0iNqcTSkC2n9YYFtr7tOiTzS3kqN03XB6fJBsYnVG4LzR5
1EdNjgGoISSVmKJwQh0Sjy+GuHDnsjtL5xPxf5OFA6bgJiYkpMgxv0VDzC3Bl6NL
oTYwnQ+/19yzoSZANlvwKi8UftHEpBXMAW2Yr3jEfKuSQIe3FPr2js2JWfpl12yQ
JZnDPEJOamxD4hvvWljENwcVMss9X9FGiQCoFIYmGja4JXZ7KqvOlgdSaS8TUqCp
qHcSJpEiJQAJQton607EjWxBBWVEMEQYZx5nLFifxwexxv1jpbGeh+ehAwRlvrZU
nXR9miv/ohw6HmopNXmXcTCJsT8/OHb4g7cUs3scUmuySMZe1dKht+o+XmkWfF9b
fgqNz9so1ls9oyg/qjuMwh5wNNUsPQJGITmzTOfGGu7engyil744flO5aQFfSwcm
zQ7ZzRh+jDPI/rG8xbrYpXXK3+xln03O/96AC6iEELA0+A2PeworEkz47nEPrqlC
Fr17Aya+rJsrN9JXL1Uz87k3XfySNc6xT8zitNGzgxtgKthUg6fU9oOt/79HuOeL
g+Y=
-----END CERTIFICATE-----

View File

@ -0,0 +1,51 @@
-----BEGIN RSA PRIVATE KEY-----
MIIJKQIBAAKCAgEAs3ZtGcSzT+E5SQSNiKXlU+EuoWXgZHpv2o+DWZoRwNOh11an
FMN1tN2EJJYTuxN9OblPwU0c+atASGtgwj8mQ1ApME4zmQLhYomc3AGc/84btkcC
1nmLbZb4ynsKyHtJPc6z2j16XXBXtMatpnFKJ9H8kc7S9R/BGBEnG/0yajNS5Fih
kiRDjA3hGDFZyZ/8fchYHwvNrv+BgXN6t0otrkSwjJEy7IOkwtuxBZAlbjiEn9fg
rwCoJNsdvUhpv52KE3vo+mC5tdoeNUuiZRjvcyZzkZcG+DM+Yfu9qTQWlREe9pDc
cLJEzYNiFrq+zTNl4v5wH2e/mDqKYMKLEbmZfDcOgnKVg5BnS5zu1x9fFsvwVzmp
l1WuubVnn8JQoifG5KmYIpPAB19+imc085AzuYP+CQEA4eBueqWhroXHdS15Br0w
7mkjeAZBw7mLnK/PWpYllfwCs7puASZMipN5VVg2/Z38Mk0XoLGYOYqU1u8S4b8c
43qC7gx8m0c7MPtjRUKBwOD1KLswtZ9v/niSFcrGuqeVGV03EzRlXa3VNxpon6RQ
wb+DcEK0Nctos8NG/nW2ewwQECKlt7aB3gVJpk4h0pkEeUE9w8/RwJdp6TLuEtmS
kGCY/s6schE86zoM++i1tcmd7gG1+89DsZTqAuu12BgAFpQ8lOKWXKnJ0dUCAwEA
AQKCAgBIdkLrKq80S75zqzDywfls+vl3FcmbCIztdREWNs2ATHOGnWhtS9bVJrRa
iXaCDQZ9LkPzyw0uCmW0WBcDl7f9afqXlJvk5nLW9LWvZ79a0oACA34z13Pi1hiy
uSfLd2xFVpbsQfKMk/X1+lrXX9sPZQxUW2x2qVGwRAzEkmGu2/ZWWSsz9QyJGnmO
6S5V6RFsQF7EemGcjXJfMJ+WLo9vVDDtMRucwDLgsxAxLNjQPmXenK4OO3epGghS
C1EXm6bK4zdZEYEq2l1kK5vwsjbNCfOUD6Uyxo4jxh/4mB2eJwGXkTpRDsoVKT2L
6+9qr5wuIYpoQ93qu4hwNV0t1QERp4HYTrX6WzDPLCtoHlIfCfzXH05jYa9n/nJD
Ow4eeeK5RE7/9/fKJWX2/R45iz87QKeS2H+ps3IOirK1P4Y/4FdIoHVyHjaCeoGp
YeSXjTWz6OEcHX9qcShdVu8ILkJlBJ4xftxNKWw8d/jKG3xOaxVNAnTGGAx113Pa
RIZcGNfroQcceV0mRAQTTpZyV7jj+lvC9lXqyP4tJJc4YjSWG6ITPrYfKVYBAhNC
K8MutrCaEH87SzIssY3DvpSzIeZMafvR4VYHZbBsDpI8+9iIijXLzDR0IQUANcah
bLWzjhDXrE5f1Et1hyHCAkCzsPIyx67QugYpJ8zTDUp6oYVaaQKCAQEA3XjPYrAJ
kokPk8ErDYTe0lMRGKvIrV3TH5hkfTIlkMfn4M6wVS3WQYYEMwgcJHVILpHkxtuD
ZcNIrey5f3IlwvSGo30w9L4IQCt5nzU0GpYqeN4Uty+0jDO0pygPYRrNAedSQSYa
jgSwoQCiu/qVzb+fPD2n6Hfxn/+cBP/mw6JBZZRb4aIhAluIt958J7NB/AwY+o03
rG6lq41UUYDtydt/yP3q7QlRBFkY+anm7O7Iy80e4oPciXbettsMuHPa8/zF1TWJ
IxG2v9C4Q++tvj9zKP66l/5JVKVdtexGtPJS9i8lIkPkqX7K9peOUDNwyVAyG4kB
2SxKysu4sTq5ZwKCAQEAz3D6UTbTnzkrQFDnKaYUpMW4jfKUFr/n+s5iIN3FPk9t
ZuZ3YbrIroW50KxmVYFG+x1vyEZvB+Bdb7apCnqK4z2x4vbpl4bsE/uW8Kk+rQxx
gbUfuih1r0FMVE0kGL8MxJEL/4owOZM9G8hinxkVst4LswNe0MPwfJHgeLrx/xAM
lHq4hS5Pb4SYb+Z6iUJlzJpsQPX+JLW3cDqlfUBB9ckijtNkTbwsq7ETWRglnAP1
flOCBe6l1aSDPiSa3fRSwA0PHTztVZt2TwnhDD2v40tTxCfeAwafeEwpUUgWivt0
Doq0Tni5lZiPjqNfmbZSa8BKyeEtPuKZcKT4QMSJYwKCAQA8+TTHa8XG5Rs3x5fN
ygX6i8oKK8k9CbbFXRRVb4fuG0tYli7v1IXHVlkzn4j39J4hzCLbKLY9Pw10bNcJ
ImkJCn9C5YWj6+mjmRSL437rzun0itfTMzwW2WlkF+BcEJ/eZUw9CXuIG/xw5xbm
f+/cTGRPln3yv4rzTNEsgzOKKtKsX7MIJLXHy2GRlZxC5dRFyyLZYCWywGe2Glvb
cI6G43qD4HxcNBNtCgaZPdCI7Ji1m0xken8uDV71ossWwTbHs5DXyTxvPkI8/v6s
HYGM/jT7VV4T2Hth5YEuQ9WXnZt/ka08iMqca37/cuxIYlEr63tQH2E15D7XJE09
5fgDAoIBAQDKC2hDofsMokoWIraEQlbpBgtzdkn2voPcLRg2msp6njIYf3DXp22/
TlBlhwVFUt0nyMwPbUrHiSh4npiWtDSCkJyqS4PJKojWDb4+ORnqwqvrgdadIrs9
L4SAt4Ho+GwfKIdfJeFCsr5aSRqFi5Eu3kbW3PmErNOXAR55eNwrah5WoBEI5spH
/AXdN8cx2ZH9borx2qbmand4wCZfkC6ujnEyW4Lek+GOeLI3nOVEyDZcDEogLQko
xUtvQ4fzlvziQdXuzGD9eKYK5bxkh9DAuaWk8I+0ssawDL5RhL0wMSog38guhjd8
FVP9wfJjbMlqWah+aOwAzARXStbhfouxAoIBAQCw0HQEHJL0nTLT0SdkEAZF12Fo
NMdTh68xtQ1y2papT87L6J8/WmK1/O32KAJ3XXikl0QMTkhjHqf+eA+27C8L+jIR
HPhB4OGdlOufk1QXoaX1Z4vVHfyyXASspfZ1ecxFsQdC/lPnQ+ir/skHI1NNC9oO
Q39d37tyoKCOyZUhD0IsT5+6vVgyj6EcwiCmgwZ7PI3MKKoXx7HmkZ9e4hVYbXP6
WsNsF2VKPHCre56T2FRT3xLxzN5uZOz0Hrau50Y/3RNuYiJJ2aufAYmv1r6HT/0W
BP2kmzmWlg1JJaU9vS9jZXabQZPp8bJ+fDtSZSa69AJESKRE9e3e895uW3gY
-----END RSA PRIVATE KEY-----

View File

@ -0,0 +1,15 @@
#![feature(plugin)]
#![plugin(rocket_codegen)]
extern crate rocket;
#[cfg(test)] mod tests;
#[get("/")]
fn hello() -> &'static str {
"Hello, world!"
}
fn main() {
rocket::ignite().mount("/", routes![hello]).launch();
}

View File

@ -0,0 +1,13 @@
use super::rocket;
use rocket::testing::MockRequest;
use rocket::http::Method::*;
#[test]
fn hello_world() {
let rocket = rocket::ignite().mount("/", routes![super::hello]);
let mut req = MockRequest::new(Get, "/");
let mut response = req.dispatch_with(&rocket);
let body_str = response.body().and_then(|body| body.into_string());
assert_eq!(body_str, Some("Hello, world!".to_string()));
}

View File

@ -16,6 +16,7 @@ categories = ["web-programming::http-server"]
[features] [features]
testing = [] testing = []
tls = ["rustls", "hyper-rustls"]
[dependencies] [dependencies]
term-painter = "0.2" term-painter = "0.2"
@ -31,6 +32,13 @@ base64 = "0.4"
smallvec = { git = "https://github.com/SergioBenitez/rust-smallvec" } smallvec = { git = "https://github.com/SergioBenitez/rust-smallvec" }
pear = "0.0.8" pear = "0.0.8"
pear_codegen = "0.0.8" pear_codegen = "0.0.8"
rustls = { version = "0.5.8", optional = true }
[dependencies.hyper-rustls]
git = "https://github.com/SergioBenitez/hyper-rustls"
default-features = false
features = ["server"]
optional = true
[dependencies.cookie] [dependencies.cookie]
version = "0.7.4" version = "0.7.4"

View File

@ -20,6 +20,8 @@ pub struct ConfigBuilder {
pub log_level: LoggingLevel, pub log_level: LoggingLevel,
/// The session key. /// The session key.
pub session_key: Option<String>, pub session_key: Option<String>,
/// TLS configuration (path to certificates file, path to private key file).
pub tls_config: Option<(String, String)>,
/// Any extra parameters that aren't part of Rocket's config. /// Any extra parameters that aren't part of Rocket's config.
pub extras: HashMap<String, Value>, pub extras: HashMap<String, Value>,
/// The root directory of this config. /// The root directory of this config.
@ -63,6 +65,7 @@ impl ConfigBuilder {
workers: config.workers, workers: config.workers,
log_level: config.log_level, log_level: config.log_level,
session_key: None, session_key: None,
tls_config: None,
extras: config.extras, extras: config.extras,
root: root_dir, root: root_dir,
} }
@ -162,6 +165,26 @@ impl ConfigBuilder {
self self
} }
/// Sets the `tls_config` in the configuration being built.
///
/// # Example
///
/// ```rust
/// use rocket::config::{Config, Environment};
///
/// let mut config = Config::build(Environment::Staging)
/// .tls("/path/to/certs.pem", "/path/to/key.pem")
/// # ; /*
/// .unwrap();
/// # */
/// ```
pub fn tls<C, K>(mut self, certs_path: C, key_path: K) -> Self
where C: Into<String>, K: Into<String>
{
self.tls_config = Some((certs_path.into(), key_path.into()));
self
}
/// Sets the `environment` in the configuration being built. /// Sets the `environment` in the configuration being built.
/// ///
/// # Example /// # Example
@ -260,6 +283,10 @@ impl ConfigBuilder {
config.set_extras(self.extras); config.set_extras(self.extras);
config.set_root(self.root); config.set_root(self.root);
if let Some((certs_path, key_path)) = self.tls_config {
config.set_tls(&certs_path, &key_path)?;
}
if let Some(key) = self.session_key { if let Some(key) = self.session_key {
config.set_session_key(key)?; config.set_session_key(key)?;
} }

View File

@ -5,10 +5,11 @@ use std::convert::AsRef;
use std::fmt; use std::fmt;
use std::env; use std::env;
use config::Environment::*; #[cfg(feature = "tls")] use rustls::{Certificate, PrivateKey};
use config::{self, Value, ConfigBuilder, Environment, ConfigError};
use {num_cpus, base64}; use {num_cpus, base64};
use config::Environment::*;
use config::{Result, Table, Value, ConfigBuilder, Environment, ConfigError};
use logger::LoggingLevel; use logger::LoggingLevel;
use http::Key; use http::Key;
@ -18,7 +19,7 @@ pub enum SessionKey {
} }
impl SessionKey { impl SessionKey {
#[inline] #[inline(always)]
pub fn kind(&self) -> &'static str { pub fn kind(&self) -> &'static str {
match *self { match *self {
SessionKey::Generated(_) => "generated", SessionKey::Generated(_) => "generated",
@ -26,7 +27,7 @@ impl SessionKey {
} }
} }
#[inline] #[inline(always)]
fn inner(&self) -> &Key { fn inner(&self) -> &Key {
match *self { match *self {
SessionKey::Generated(ref key) | SessionKey::Provided(ref key) => key SessionKey::Generated(ref key) | SessionKey::Provided(ref key) => key
@ -34,6 +35,15 @@ impl SessionKey {
} }
} }
#[cfg(feature = "tls")]
pub struct TlsConfig {
pub certs: Vec<Certificate>,
pub key: PrivateKey
}
#[cfg(not(feature = "tls"))]
pub struct TlsConfig;
/// Structure for Rocket application configuration. /// Structure for Rocket application configuration.
/// ///
/// A `Config` structure is typically built using the [build](#method.build) /// A `Config` structure is typically built using the [build](#method.build)
@ -61,20 +71,73 @@ pub struct Config {
pub workers: u16, pub workers: u16,
/// How much information to log. /// How much information to log.
pub log_level: LoggingLevel, pub log_level: LoggingLevel,
/// The session key.
pub(crate) session_key: SessionKey,
/// TLS configuration.
pub(crate) tls: Option<TlsConfig>,
/// Extra parameters that aren't part of Rocket's core config. /// Extra parameters that aren't part of Rocket's core config.
pub extras: HashMap<String, Value>, pub extras: HashMap<String, Value>,
/// The path to the configuration file this config belongs to. /// The path to the configuration file this config belongs to.
pub config_path: PathBuf, pub config_path: PathBuf,
/// The session key.
pub(crate) session_key: SessionKey,
} }
macro_rules! parse { macro_rules! config_from_raw {
($conf:expr, $name:expr, $val:expr, $method:ident, $expect: expr) => ( ($config:expr, $name:expr, $value:expr,
$val.$method().ok_or_else(|| { $($key:ident => ($type:ident, $set:ident, $map:expr)),+ | _ => $rest:expr) => (
$conf.bad_type($name, $val.type_str(), $expect) match $name {
}) $(stringify!($key) => {
); concat_idents!(value_as_, $type)($config, $name, $value)
.and_then(|parsed| $map($config.$set(parsed)))
})+
_ => $rest
}
)
}
#[inline(always)]
fn value_as_str<'a>(config: &Config, name: &str, value: &'a Value) -> Result<&'a str> {
value.as_str().ok_or(config.bad_type(name, value.type_str(), "a string"))
}
#[inline(always)]
fn value_as_u16(config: &Config, name: &str, value: &Value) -> Result<u16> {
match value.as_integer() {
Some(x) if x >= 0 && x <= (u16::max_value() as i64) => Ok(x as u16),
_ => Err(config.bad_type(name, value.type_str(), "a 16-bit unsigned integer"))
}
}
#[inline(always)]
fn value_as_log_level(config: &Config, name: &str, value: &Value) -> Result<LoggingLevel> {
value_as_str(config, name, value)
.and_then(|s| s.parse().map_err(|e| config.bad_type(name, value.type_str(), e)))
}
#[inline(always)]
fn value_as_tls_config<'v>(config: &Config,
name: &str,
value: &'v Value,
) -> Result<(&'v str, &'v str)>
{
let (mut certs_path, mut key_path) = (None, None);
let table = value.as_table()
.ok_or_else(|| config.bad_type(name, value.type_str(), "a table"))?;
let env = config.environment;
for (key, value) in table {
match key.as_str() {
"certs" => certs_path = Some(value_as_str(config, "tls.certs", value)?),
"key" => key_path = Some(value_as_str(config, "tls.key", value)?),
_ => return Err(ConfigError::UnknownKey(format!("{}.tls.{}", env, key)))
}
}
if let (Some(certs), Some(key)) = (certs_path, key_path) {
Ok((certs, key))
} else {
Err(config.bad_type(name, "a table with missing entries",
"a table with `certs` and `key` entries"))
}
} }
impl Config { impl Config {
@ -119,7 +182,7 @@ impl Config {
/// let mut my_config = Config::new(Environment::Production).expect("cwd"); /// let mut my_config = Config::new(Environment::Production).expect("cwd");
/// my_config.set_port(1001); /// my_config.set_port(1001);
/// ``` /// ```
pub fn new(env: Environment) -> config::Result<Config> { pub fn new(env: Environment) -> Result<Config> {
let cwd = env::current_dir().map_err(|_| ConfigError::BadCWD)?; let cwd = env::current_dir().map_err(|_| ConfigError::BadCWD)?;
Config::default(env, cwd.as_path().join("Rocket.custom.toml")) Config::default(env, cwd.as_path().join("Rocket.custom.toml"))
} }
@ -131,7 +194,7 @@ impl Config {
/// # Panics /// # Panics
/// ///
/// Panics if randomness cannot be retrieved from the OS. /// Panics if randomness cannot be retrieved from the OS.
pub(crate) fn default<P>(env: Environment, path: P) -> config::Result<Config> pub(crate) fn default<P>(env: Environment, path: P) -> Result<Config>
where P: AsRef<Path> where P: AsRef<Path>
{ {
let config_path = path.as_ref().to_path_buf(); let config_path = path.as_ref().to_path_buf();
@ -155,6 +218,7 @@ impl Config {
workers: default_workers, workers: default_workers,
log_level: LoggingLevel::Normal, log_level: LoggingLevel::Normal,
session_key: key, session_key: key,
tls: None,
extras: HashMap::new(), extras: HashMap::new(),
config_path: config_path, config_path: config_path,
} }
@ -167,6 +231,7 @@ impl Config {
workers: default_workers, workers: default_workers,
log_level: LoggingLevel::Normal, log_level: LoggingLevel::Normal,
session_key: key, session_key: key,
tls: None,
extras: HashMap::new(), extras: HashMap::new(),
config_path: config_path, config_path: config_path,
} }
@ -179,6 +244,7 @@ impl Config {
workers: default_workers, workers: default_workers,
log_level: LoggingLevel::Critical, log_level: LoggingLevel::Critical,
session_key: key, session_key: key,
tls: None,
extras: HashMap::new(), extras: HashMap::new(),
config_path: config_path, config_path: config_path,
} }
@ -209,40 +275,22 @@ impl Config {
/// * **workers**: Integer (16-bit unsigned) /// * **workers**: Integer (16-bit unsigned)
/// * **log**: String /// * **log**: String
/// * **session_key**: String (192-bit base64) /// * **session_key**: String (192-bit base64)
pub(crate) fn set_raw(&mut self, name: &str, val: &Value) -> config::Result<()> { /// * **tls**: Table (`certs` (path as String), `key` (path as String))
if name == "address" { pub(crate) fn set_raw(&mut self, name: &str, val: &Value) -> Result<()> {
let address_str = parse!(self, name, val, as_str, "a string")?; let (id, ok) = (|val| val, |_| Ok(()));
self.set_address(address_str)?; config_from_raw!(self, name, val,
} else if name == "port" { address => (str, set_address, id),
let port = parse!(self, name, val, as_integer, "an integer")?; port => (u16, set_port, ok),
if port < 0 || port > (u16::max_value() as i64) { workers => (u16, set_workers, ok),
return Err(self.bad_type(name, val.type_str(), "a 16-bit unsigned integer")) session_key => (str, set_session_key, id),
} log => (log_level, set_log_level, ok),
tls => (tls_config, set_raw_tls, id)
self.set_port(port as u16); | _ => {
} else if name == "workers" {
let workers = parse!(self, name, val, as_integer, "an integer")?;
if workers < 0 || workers > (u16::max_value() as i64) {
return Err(self.bad_type(name, val.type_str(), "a 16-bit unsigned integer"));
}
self.set_workers(workers as u16);
} else if name == "session_key" {
let key = parse!(self, name, val, as_str, "a string")?;
self.set_session_key(key)?;
} else if name == "log" {
let level_str = parse!(self, name, val, as_str, "a string")?;
let expect = "log level ('normal', 'critical', 'debug')";
match level_str.parse() {
Ok(level) => self.set_log_level(level),
Err(_) => return Err(self.bad_type(name, val.type_str(), expect))
}
} else {
self.extras.insert(name.into(), val.clone()); self.extras.insert(name.into(), val.clone());
}
Ok(()) Ok(())
} }
)
}
/// Sets the root directory of this configuration to `root`. /// Sets the root directory of this configuration to `root`.
/// ///
@ -286,7 +334,7 @@ impl Config {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
pub fn set_address<A: Into<String>>(&mut self, address: A) -> config::Result<()> { pub fn set_address<A: Into<String>>(&mut self, address: A) -> Result<()> {
let address = address.into(); let address = address.into();
if address.parse::<IpAddr>().is_err() && lookup_host(&address).is_err() { if address.parse::<IpAddr>().is_err() && lookup_host(&address).is_err() {
return Err(self.bad_type("address", "string", "a valid hostname or IP")); return Err(self.bad_type("address", "string", "a valid hostname or IP"));
@ -310,6 +358,7 @@ impl Config {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
#[inline]
pub fn set_port(&mut self, port: u16) { pub fn set_port(&mut self, port: u16) {
self.port = port; self.port = port;
} }
@ -328,6 +377,7 @@ impl Config {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
#[inline]
pub fn set_workers(&mut self, workers: u16) { pub fn set_workers(&mut self, workers: u16) {
self.workers = workers; self.workers = workers;
} }
@ -354,7 +404,7 @@ impl Config {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
pub fn set_session_key<K: Into<String>>(&mut self, key: K) -> config::Result<()> { pub fn set_session_key<K: Into<String>>(&mut self, key: K) -> Result<()> {
let key = key.into(); let key = key.into();
let error = self.bad_type("session_key", "string", let error = self.bad_type("session_key", "string",
"a 256-bit base64 encoded string"); "a 256-bit base64 encoded string");
@ -387,10 +437,42 @@ impl Config {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
#[inline]
pub fn set_log_level(&mut self, log_level: LoggingLevel) { pub fn set_log_level(&mut self, log_level: LoggingLevel) {
self.log_level = log_level; self.log_level = log_level;
} }
#[cfg(feature = "tls")]
pub fn set_tls(&mut self, certs_path: &str, key_path: &str) -> Result<()> {
use hyper_rustls::util as tls;
let err = "nonexistent or invalid file";
let certs = tls::load_certs(certs_path)
.map_err(|_| self.bad_type("tls", err, "a readable certificates file"))?;
let key = tls::load_private_key(key_path)
.map_err(|_| self.bad_type("tls", err, "a readable private key file"))?;
self.tls = Some(TlsConfig { certs, key });
Ok(())
}
#[cfg(not(feature = "tls"))]
pub fn set_tls(&mut self, _: &str, _: &str) -> Result<()> {
self.tls = Some(TlsConfig);
Ok(())
}
#[cfg(not(test))]
#[inline(always)]
fn set_raw_tls(&mut self, paths: (&str, &str)) -> Result<()> {
self.set_tls(paths.0, paths.1)
}
#[cfg(test)]
fn set_raw_tls(&mut self, _: (&str, &str)) -> Result<()> {
Ok(())
}
/// Sets the extras for `self` to be the key/value pairs in `extras`. /// Sets the extras for `self` to be the key/value pairs in `extras`.
/// encoded string. /// encoded string.
/// ///
@ -413,6 +495,7 @@ impl Config {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
#[inline]
pub fn set_extras(&mut self, extras: HashMap<String, Value>) { pub fn set_extras(&mut self, extras: HashMap<String, Value>) {
self.extras = extras; self.extras = extras;
} }
@ -441,6 +524,7 @@ impl Config {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
#[inline]
pub fn extras<'a>(&'a self) -> impl Iterator<Item=(&'a str, &'a Value)> { pub fn extras<'a>(&'a self) -> impl Iterator<Item=(&'a str, &'a Value)> {
self.extras.iter().map(|(k, v)| (k.as_str(), v)) self.extras.iter().map(|(k, v)| (k.as_str(), v))
} }
@ -470,9 +554,9 @@ impl Config {
/// ///
/// assert_eq!(config.get_str("my_extra"), Ok("extra_value")); /// assert_eq!(config.get_str("my_extra"), Ok("extra_value"));
/// ``` /// ```
pub fn get_str<'a>(&'a self, name: &str) -> config::Result<&'a str> { pub fn get_str<'a>(&'a self, name: &str) -> Result<&'a str> {
let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?;
parse!(self, name, value, as_str, "a string") val.as_str().ok_or_else(|| self.bad_type(name, val.type_str(), "a string"))
} }
/// Attempts to retrieve the extra named `name` as an integer. /// Attempts to retrieve the extra named `name` as an integer.
@ -494,9 +578,9 @@ impl Config {
/// ///
/// assert_eq!(config.get_int("my_extra"), Ok(1025)); /// assert_eq!(config.get_int("my_extra"), Ok(1025));
/// ``` /// ```
pub fn get_int(&self, name: &str) -> config::Result<i64> { pub fn get_int(&self, name: &str) -> Result<i64> {
let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?;
parse!(self, name, value, as_integer, "an integer") val.as_integer().ok_or_else(|| self.bad_type(name, val.type_str(), "an integer"))
} }
/// Attempts to retrieve the extra named `name` as a boolean. /// Attempts to retrieve the extra named `name` as a boolean.
@ -518,9 +602,9 @@ impl Config {
/// ///
/// assert_eq!(config.get_bool("my_extra"), Ok(true)); /// assert_eq!(config.get_bool("my_extra"), Ok(true));
/// ``` /// ```
pub fn get_bool(&self, name: &str) -> config::Result<bool> { pub fn get_bool(&self, name: &str) -> Result<bool> {
let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?;
parse!(self, name, value, as_bool, "a boolean") val.as_bool().ok_or_else(|| self.bad_type(name, val.type_str(), "a boolean"))
} }
/// Attempts to retrieve the extra named `name` as a float. /// Attempts to retrieve the extra named `name` as a float.
@ -542,9 +626,9 @@ impl Config {
/// ///
/// assert_eq!(config.get_float("pi"), Ok(3.14159)); /// assert_eq!(config.get_float("pi"), Ok(3.14159));
/// ``` /// ```
pub fn get_float(&self, name: &str) -> config::Result<f64> { pub fn get_float(&self, name: &str) -> Result<f64> {
let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?;
parse!(self, name, value, as_float, "a float") val.as_float().ok_or_else(|| self.bad_type(name, val.type_str(), "a float"))
} }
/// Attempts to retrieve the extra named `name` as a slice of an array. /// Attempts to retrieve the extra named `name` as a slice of an array.
@ -566,9 +650,9 @@ impl Config {
/// ///
/// assert!(config.get_slice("numbers").is_ok()); /// assert!(config.get_slice("numbers").is_ok());
/// ``` /// ```
pub fn get_slice(&self, name: &str) -> config::Result<&[Value]> { pub fn get_slice(&self, name: &str) -> Result<&[Value]> {
let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?;
parse!(self, name, value, as_slice, "a slice") val.as_slice().ok_or_else(|| self.bad_type(name, val.type_str(), "a slice"))
} }
/// Attempts to retrieve the extra named `name` as a table. /// Attempts to retrieve the extra named `name` as a table.
@ -594,9 +678,9 @@ impl Config {
/// ///
/// assert!(config.get_table("my_table").is_ok()); /// assert!(config.get_table("my_table").is_ok());
/// ``` /// ```
pub fn get_table(&self, name: &str) -> config::Result<&config::Table> { pub fn get_table(&self, name: &str) -> Result<&Table> {
let value = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?; let val = self.extras.get(name).ok_or_else(|| ConfigError::NotFound)?;
parse!(self, name, value, as_table, "a table") val.as_table().ok_or_else(|| self.bad_type(name, val.type_str(), "a table"))
} }
/// Returns the path at which the configuration file for `self` is stored. /// Returns the path at which the configuration file for `self` is stored.

View File

@ -52,8 +52,12 @@ pub enum ConfigError {
ParseError(String, PathBuf, Vec<ParsingError>), ParseError(String, PathBuf, Vec<ParsingError>),
/// There was a TOML parsing error in a config environment variable. /// There was a TOML parsing error in a config environment variable.
/// ///
/// Parameters: (env_key, env_value, expected type) /// Parameters: (env_key, env_value, error)
BadEnvVal(String, String, &'static str), BadEnvVal(String, String, String),
/// The entry (key) is unknown.
///
/// Parameters: (key)
UnknownKey(String),
} }
impl ConfigError { impl ConfigError {
@ -95,11 +99,14 @@ impl ConfigError {
trace_!("'{}' - {}", error_source, White.paint(&error.desc)); trace_!("'{}' - {}", error_source, White.paint(&error.desc));
} }
} }
BadEnvVal(ref key, ref value, ref expected) => { BadEnvVal(ref key, ref value, ref error) => {
error!("environment variable '{}={}' could not be parsed", error!("environment variable '{}={}' could not be parsed",
White.paint(key), White.paint(value)); White.paint(key), White.paint(value));
info_!("value for {:?} must be {}", info_!("{}", White.paint(error));
White.paint(key), White.paint(expected)) }
UnknownKey(ref key) => {
error!("the configuration key '{}' is unknown and disallowed in \
this position", White.paint(key));
} }
} }
} }
@ -123,6 +130,7 @@ impl fmt::Display for ConfigError {
BadFilePath(ref p, _) => write!(f, "{:?} is not a valid config path", p), BadFilePath(ref p, _) => write!(f, "{:?} is not a valid config path", p),
BadEnv(ref e) => write!(f, "{:?} is not a valid `ROCKET_ENV` value", e), BadEnv(ref e) => write!(f, "{:?} is not a valid `ROCKET_ENV` value", e),
ParseError(..) => write!(f, "the config file contains invalid TOML"), ParseError(..) => write!(f, "the config file contains invalid TOML"),
UnknownKey(ref k) => write!(f, "'{}' is an unknown key", k),
BadEntry(ref e, _) => { BadEntry(ref e, _) => {
write!(f, "{:?} is not a valid `[environment]` entry", e) write!(f, "{:?} is not a valid `[environment]` entry", e)
} }
@ -148,6 +156,7 @@ impl Error for ConfigError {
ParseError(..) => "the config file contains invalid TOML", ParseError(..) => "the config file contains invalid TOML",
BadType(..) => "a key was specified with a value of the wrong type", BadType(..) => "a key was specified with a value of the wrong type",
BadEnvVal(..) => "an environment variable could not be parsed", BadEnvVal(..) => "an environment variable could not be parsed",
UnknownKey(..) => "an unknown key was used in a disallowed position",
} }
} }
} }

View File

@ -43,6 +43,10 @@
//! * **session_key**: _[string]_ a 256-bit base64 encoded string (44 //! * **session_key**: _[string]_ a 256-bit base64 encoded string (44
//! characters) to use as the session key //! characters) to use as the session key
//! * example: `"8Xui8SN4mI+7egV/9dlfYYLGQJeEx4+DwmSQLwDVXJg="` //! * example: `"8Xui8SN4mI+7egV/9dlfYYLGQJeEx4+DwmSQLwDVXJg="`
//! * **tls**: _[dict]_ a dictionary with two keys: 1) `certs`: _[string]_ a
//! path to a certificate chain in PEM format, and 2) `key`: _[string]_ a
//! path to a private key file in PEM format for the certificate in `certs`
//! * example: `{ certs = "/path/to/certs.pem", key = "/path/to/key.pem" }`
//! //!
//! ### Rocket.toml //! ### Rocket.toml
//! //!
@ -118,6 +122,31 @@
//! //!
//! Environment variables take precedence over all other configuration methods: //! Environment variables take precedence over all other configuration methods:
//! if the variable is set, it will be used as the value for the parameter. //! if the variable is set, it will be used as the value for the parameter.
//! Variable values are parsed as if they were TOML syntax. As illustration,
//! consider the following examples:
//!
//! ```sh
//! ROCKET_INTEGER=1
//! ROCKET_FLOAT=3.14
//! ROCKET_STRING=Hello
//! ROCKET_STRING="Hello"
//! ROCKET_BOOL=true
//! ROCKET_ARRAY=[1,"b",3.14]
//! ROCKET_DICT={key="abc",val=123}
//! ```
//!
//! ### TLS Configuration
//!
//! TLS can be enabled by specifying the `tls.key` and `tls.certs` parameters.
//! Rocket must be compiled with the `tls` feature enabled for the parameters to
//! take effect. The recommended way to specify the parameters is via the
//! `global` environment:
//!
//! ```toml
//! [global.tls]
//! certs = "/path/to/certs.pem"
//! key = "/path/to/key.pem"
//! ```
//! //!
//! ## Retrieving Configuration Parameters //! ## Retrieving Configuration Parameters
//! //!
@ -278,6 +307,7 @@ impl RocketConfig {
Err(ConfigError::NotFound) Err(ConfigError::NotFound)
} }
#[inline]
fn get_mut(&mut self, env: Environment) -> &mut Config { fn get_mut(&mut self, env: Environment) -> &mut Config {
match self.config.get_mut(&env) { match self.config.get_mut(&env) {
Some(config) => config, Some(config) => config,
@ -306,33 +336,37 @@ impl RocketConfig {
} }
/// Retrieves the `Config` for the active environment. /// Retrieves the `Config` for the active environment.
#[inline]
pub fn active(&self) -> &Config { pub fn active(&self) -> &Config {
self.get(self.active_env) self.get(self.active_env)
} }
// Override all environments with values from env variables if present. // Override all environments with values from env variables if present.
fn override_from_env(&mut self) -> Result<()> { fn override_from_env(&mut self) -> Result<()> {
'outer: for (env_key, env_val) in env::vars() { for (key, val) in env::vars() {
if env_key.len() < ENV_VAR_PREFIX.len() { if key.len() < ENV_VAR_PREFIX.len() {
continue continue
} else if !uncased_eq(&env_key[..ENV_VAR_PREFIX.len()], ENV_VAR_PREFIX) { } else if !uncased_eq(&key[..ENV_VAR_PREFIX.len()], ENV_VAR_PREFIX) {
continue continue
} }
// Skip environment variables that are handled elsewhere. // Skip environment variables that are handled elsewhere.
for prehandled_var in PREHANDLED_VARS.iter() { if PREHANDLED_VARS.iter().any(|var| uncased_eq(&key, var)) {
if uncased_eq(&env_key, &prehandled_var) { continue
continue 'outer
}
} }
// Parse the key and value and try to set the variable for all envs. // Parse the key and value and try to set the variable for all envs.
let key = env_key[ENV_VAR_PREFIX.len()..].to_lowercase(); let key = key[ENV_VAR_PREFIX.len()..].to_lowercase();
let val = parse_simple_toml_value(&env_val); let toml_val = match parse_simple_toml_value(&val) {
Ok(val) => val,
Err(e) => return Err(ConfigError::BadEnvVal(key, val, e.into()))
};
for env in &Environment::all() { for env in &Environment::all() {
match self.get_mut(*env).set_raw(&key, &val) { match self.get_mut(*env).set_raw(&key, &toml_val) {
Err(ConfigError::BadType(_, exp, _, _)) => { Err(ConfigError::BadType(_, exp, actual, _)) => {
return Err(ConfigError::BadEnvVal(env_key, env_val, exp)) let e = format!("expected {}, but found {}", exp, actual);
return Err(ConfigError::BadEnvVal(key, val, e))
} }
Err(e) => return Err(e), Err(e) => return Err(e),
Ok(_) => { /* move along */ } Ok(_) => { /* move along */ }
@ -458,7 +492,7 @@ unsafe fn private_init() {
let config = RocketConfig::read().unwrap_or_else(|e| { let config = RocketConfig::read().unwrap_or_else(|e| {
match e { match e {
ParseError(..) | BadEntry(..) | BadEnv(..) | BadType(..) ParseError(..) | BadEntry(..) | BadEnv(..) | BadType(..)
| BadFilePath(..) | BadEnvVal(..) => bail(e), | BadFilePath(..) | BadEnvVal(..) | UnknownKey(..) => bail(e),
IOError | BadCWD => warn!("Failed reading Rocket.toml. Using defaults."), IOError | BadCWD => warn!("Failed reading Rocket.toml. Using defaults."),
NotFound => { /* try using the default below */ } NotFound => { /* try using the default below */ }
} }
@ -697,6 +731,64 @@ mod test {
"#.to_string(), TEST_CONFIG_FILENAME).is_err()); "#.to_string(), TEST_CONFIG_FILENAME).is_err());
} }
// Only do this test when the tls feature is disabled since the file paths
// we're supplying don't actually exist.
#[test]
fn test_good_tls_values() {
// Take the lock so changing the environment doesn't cause races.
let _env_lock = ENV_LOCK.lock().unwrap();
env::set_var(CONFIG_ENV, "dev");
assert!(RocketConfig::parse(r#"
[staging]
tls = { certs = "some/path.pem", key = "some/key.pem" }
"#.to_string(), TEST_CONFIG_FILENAME).is_ok());
assert!(RocketConfig::parse(r#"
[staging.tls]
certs = "some/path.pem"
key = "some/key.pem"
"#.to_string(), TEST_CONFIG_FILENAME).is_ok());
assert!(RocketConfig::parse(r#"
[global.tls]
certs = "some/path.pem"
key = "some/key.pem"
"#.to_string(), TEST_CONFIG_FILENAME).is_ok());
assert!(RocketConfig::parse(r#"
[global]
tls = { certs = "some/path.pem", key = "some/key.pem" }
"#.to_string(), TEST_CONFIG_FILENAME).is_ok());
}
#[test]
fn test_bad_tls_config() {
// Take the lock so changing the environment doesn't cause races.
let _env_lock = ENV_LOCK.lock().unwrap();
env::remove_var(CONFIG_ENV);
assert!(RocketConfig::parse(r#"
[development]
tls = "hello"
"#.to_string(), TEST_CONFIG_FILENAME).is_err());
assert!(RocketConfig::parse(r#"
[development]
tls = { certs = "some/path.pem" }
"#.to_string(), TEST_CONFIG_FILENAME).is_err());
assert!(RocketConfig::parse(r#"
[development]
tls = { certs = "some/path.pem", key = "some/key.pem", extra = "bah" }
"#.to_string(), TEST_CONFIG_FILENAME).is_err());
assert!(RocketConfig::parse(r#"
[staging]
tls = { cert = "some/path.pem", key = "some/key.pem" }
"#.to_string(), TEST_CONFIG_FILENAME).is_err());
}
#[test] #[test]
fn test_good_port_values() { fn test_good_port_values() {
// Take the lock so changing the environment doesn't cause races. // Take the lock so changing the environment doesn't cause races.

View File

@ -4,20 +4,63 @@ use std::str::FromStr;
use config::Value; use config::Value;
pub fn parse_simple_toml_value(string: &str) -> Value { pub fn parse_simple_toml_value(string: &str) -> Result<Value, &'static str> {
if let Ok(int) = i64::from_str(string) { if string.is_empty() {
return Value::Integer(int) return Err("value is empty")
} }
if let Ok(boolean) = bool::from_str(string) { let value = if let Ok(int) = i64::from_str(string) {
return Value::Boolean(boolean) Value::Integer(int)
} else if let Ok(float) = f64::from_str(string) {
Value::Float(float)
} else if let Ok(boolean) = bool::from_str(string) {
Value::Boolean(boolean)
} else if string.starts_with('{') {
if !string.ends_with('}') {
return Err("value is missing closing '}'")
} }
if let Ok(float) = f64::from_str(string) { let mut table = BTreeMap::new();
return Value::Float(float) let inner = &string[1..string.len() - 1].trim();
if !inner.is_empty() {
for key_val in inner.split(',') {
let (key, val) = match key_val.find('=') {
Some(i) => (&key_val[..i], &key_val[(i + 1)..]),
None => return Err("missing '=' in dicitonary key/value pair")
};
let key = key.trim().to_string();
let val = parse_simple_toml_value(val.trim())?;
table.insert(key, val);
}
} }
Value::Table(table)
} else if string.starts_with('[') {
if !string.ends_with(']') {
return Err("value is missing closing ']'")
}
let mut vals = vec![];
let inner = &string[1..string.len() - 1].trim();
if !inner.is_empty() {
for val_str in inner.split(',') {
vals.push(parse_simple_toml_value(val_str.trim())?);
}
}
Value::Array(vals)
} else if string.starts_with('"') {
if !string[1..].ends_with('"') {
return Err("value is missing closing '\"'");
}
Value::String(string[1..string.len() - 1].to_string())
} else {
Value::String(string.to_string()) Value::String(string.to_string())
};
Ok(value)
} }
/// Conversion trait from standard types into TOML `Value`s. /// Conversion trait from standard types into TOML `Value`s.
@ -27,18 +70,21 @@ pub trait IntoValue {
} }
impl<'a> IntoValue for &'a str { impl<'a> IntoValue for &'a str {
#[inline(always)]
fn into_value(self) -> Value { fn into_value(self) -> Value {
Value::String(self.to_string()) Value::String(self.to_string())
} }
} }
impl IntoValue for Value { impl IntoValue for Value {
#[inline(always)]
fn into_value(self) -> Value { fn into_value(self) -> Value {
self self
} }
} }
impl<V: IntoValue> IntoValue for Vec<V> { impl<V: IntoValue> IntoValue for Vec<V> {
#[inline(always)]
fn into_value(self) -> Value { fn into_value(self) -> Value {
Value::Array(self.into_iter().map(|v| v.into_value()).collect()) Value::Array(self.into_iter().map(|v| v.into_value()).collect())
} }
@ -87,3 +133,50 @@ impl_into_value!(Boolean: bool);
impl_into_value!(Float: f64); impl_into_value!(Float: f64);
impl_into_value!(Float: f32, as f64); impl_into_value!(Float: f32, as f64);
#[cfg(test)]
mod test {
use std::collections::BTreeMap;
use super::parse_simple_toml_value;
use super::IntoValue;
use super::Value::*;
macro_rules! assert_parse {
($string:expr, $value:expr) => (
match parse_simple_toml_value($string) {
Ok(value) => assert_eq!(value, $value),
Err(e) => panic!("{:?} failed to parse: {:?}", $string, e)
};
)
}
#[test]
fn parse_toml_values() {
assert_parse!("1", Integer(1));
assert_parse!("1.32", Float(1.32));
assert_parse!("true", Boolean(true));
assert_parse!("false", Boolean(false));
assert_parse!("hello, WORLD!", String("hello, WORLD!".into()));
assert_parse!("hi", String("hi".into()));
assert_parse!("\"hi\"", String("hi".into()));
assert_parse!("[]", Array(Vec::new()));
assert_parse!("[1]", vec![1].into_value());
assert_parse!("[1, 2, 3]", vec![1, 2, 3].into_value());
assert_parse!("[1.32, 2]",
vec![1.32.into_value(), 2.into_value()].into_value());
assert_parse!("{}", Table(BTreeMap::new()));
assert_parse!("{a=b}", Table({
let mut map = BTreeMap::new();
map.insert("a".into(), "b".into_value());
map
}));
assert_parse!("{v=1, on=true,pi=3.14}", Table({
let mut map = BTreeMap::new();
map.insert("v".into(), 1.into_value());
map.insert("on".into(), true.into_value());
map.insert("pi".into(), 3.14.into_value());
map
}));
}
}

View File

@ -4,9 +4,10 @@ use std::fs::File;
use std::time::Duration; use std::time::Duration;
use std::mem::transmute; use std::mem::transmute;
use super::data_stream::{DataStream, StreamReader, kill_stream}; #[cfg(feature = "tls")] use hyper_rustls::WrappedStream;
use ext::ReadExt; use ext::ReadExt;
use super::data_stream::{DataStream, HyperNetStream, StreamReader, kill_stream};
use http::hyper::h1::HttpReader; use http::hyper::h1::HttpReader;
use http::hyper::buffer; use http::hyper::buffer;
@ -82,32 +83,47 @@ impl Data {
DataStream::new(stream, network) DataStream::new(stream, network)
} }
// FIXME: This is absolutely terrible (downcasting!), thanks to Hyper.
pub(crate) fn from_hyp(mut h_body: BodyReader) -> Result<Data, &'static str> { pub(crate) fn from_hyp(mut h_body: BodyReader) -> Result<Data, &'static str> {
// FIXME: This is asolutely terrible, thanks to Hyper. // Create the Data object from hyper's buffer.
let (vec, pos, cap) = h_body.get_mut().take_buf();
let net_stream = h_body.get_ref().get_ref();
#[cfg(feature = "tls")]
fn concrete_stream(stream: &&mut NetworkStream) -> Option<HyperNetStream> {
stream.downcast_ref::<HttpStream>()
.map(|s| HyperNetStream::Http(s.clone()))
.or_else(|| {
stream.downcast_ref::<WrappedStream>()
.map(|s| HyperNetStream::Https(s.clone()))
})
}
#[cfg(not(feature = "tls"))]
fn concrete_stream(stream: &&mut NetworkStream) -> Option<HyperNetStream> {
stream.downcast_ref::<HttpStream>()
.map(|s| HyperNetStream::Http(s.clone()))
}
// Retrieve the underlying HTTPStream from Hyper. // Retrieve the underlying HTTPStream from Hyper.
let mut stream = match h_body.get_ref().get_ref() let stream = match concrete_stream(net_stream) {
.downcast_ref::<HttpStream>() { Some(stream) => stream,
Some(s) => { None => return Err("Stream is not an HTTP(s) stream!")
let owned_stream = s.clone();
let buf_len = h_body.get_ref().get_buf().len() as u64;
match h_body {
SizedReader(_, n) => SizedReader(owned_stream, n - buf_len),
EofReader(_) => EofReader(owned_stream),
EmptyReader(_) => EmptyReader(owned_stream),
ChunkedReader(_, n) =>
ChunkedReader(owned_stream, n.map(|k| k - buf_len)),
}
},
None => return Err("Stream is not an HTTP stream!"),
}; };
// Set the read timeout to 5 seconds. // Set the read timeout to 5 seconds.
stream.get_mut().set_read_timeout(Some(Duration::from_secs(5))).unwrap(); stream.set_read_timeout(Some(Duration::from_secs(5))).expect("timeout set");
// Create the Data object from hyper's buffer. // Create a reader from the stream. Don't read what's already buffered.
let (vec, pos, cap) = h_body.get_mut().take_buf(); let buffered = (cap - pos) as u64;
Ok(Data::new(vec, pos, cap, stream)) let reader = match h_body {
SizedReader(_, n) => SizedReader(stream, n - buffered),
EofReader(_) => EofReader(stream),
EmptyReader(_) => EmptyReader(stream),
ChunkedReader(_, n) => ChunkedReader(stream, n.map(|k| k - buffered)),
};
Ok(Data::new(vec, pos, cap, reader))
} }
/// Retrieve the `peek` buffer. /// Retrieve the `peek` buffer.
@ -153,8 +169,8 @@ impl Data {
pub(crate) fn new(mut buf: Vec<u8>, pub(crate) fn new(mut buf: Vec<u8>,
pos: usize, pos: usize,
mut cap: usize, mut cap: usize,
mut stream: StreamReader) mut stream: StreamReader
-> Data { ) -> Data {
// Make sure the buffer is large enough for the bytes we want to peek. // Make sure the buffer is large enough for the bytes we want to peek.
const PEEK_BYTES: usize = 4096; const PEEK_BYTES: usize = 4096;
if buf.len() < PEEK_BYTES { if buf.len() < PEEK_BYTES {
@ -198,4 +214,3 @@ impl Drop for Data {
} }
} }
} }

View File

@ -1,12 +1,80 @@
use std::io::{self, BufRead, Read, Cursor, BufReader, Chain, Take}; use std::io::{self, BufRead, Read, Cursor, BufReader, Chain, Take};
use std::net::Shutdown; use std::net::{SocketAddr, Shutdown};
use std::time::Duration;
#[cfg(feature = "tls")] use hyper_rustls::WrappedStream as RustlsStream;
use http::hyper::net::{HttpStream, NetworkStream}; use http::hyper::net::{HttpStream, NetworkStream};
use http::hyper::h1::HttpReader; use http::hyper::h1::HttpReader;
pub type StreamReader = HttpReader<HttpStream>; pub type StreamReader = HttpReader<HyperNetStream>;
pub type InnerStream = Chain<Take<Cursor<Vec<u8>>>, BufReader<StreamReader>>; pub type InnerStream = Chain<Take<Cursor<Vec<u8>>>, BufReader<StreamReader>>;
#[derive(Clone)]
pub enum HyperNetStream {
Http(HttpStream),
#[cfg(feature = "tls")]
Https(RustlsStream)
}
macro_rules! with_inner {
($net:expr, |$stream:ident| $body:expr) => ({
trace!("{}:{}", file!(), line!());
match *$net {
HyperNetStream::Http(ref $stream) => $body,
#[cfg(feature = "tls")] HyperNetStream::Https(ref $stream) => $body
}
});
($net:expr, |mut $stream:ident| $body:expr) => ({
trace!("{}:{}", file!(), line!());
match *$net {
HyperNetStream::Http(ref mut $stream) => $body,
#[cfg(feature = "tls")] HyperNetStream::Https(ref mut $stream) => $body
}
})
}
impl io::Read for HyperNetStream {
#[inline(always)]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
with_inner!(self, |mut stream| io::Read::read(stream, buf))
}
}
impl io::Write for HyperNetStream {
#[inline(always)]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
with_inner!(self, |mut stream| io::Write::write(stream, buf))
}
#[inline(always)]
fn flush(&mut self) -> io::Result<()> {
with_inner!(self, |mut stream| io::Write::flush(stream))
}
}
impl NetworkStream for HyperNetStream {
#[inline(always)]
fn peer_addr(&mut self) -> io::Result<SocketAddr> {
with_inner!(self, |mut stream| NetworkStream::peer_addr(stream))
}
#[inline(always)]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
with_inner!(self, |stream| NetworkStream::set_read_timeout(stream, dur))
}
#[inline(always)]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
with_inner!(self, |stream| NetworkStream::set_write_timeout(stream, dur))
}
#[inline(always)]
fn close(&mut self, how: Shutdown) -> io::Result<()> {
with_inner!(self, |mut stream| NetworkStream::close(stream, how))
}
}
/// Raw data stream of a request body. /// Raw data stream of a request body.
/// ///
/// This stream can only be obtained by calling /// This stream can only be obtained by calling
@ -15,33 +83,38 @@ pub type InnerStream = Chain<Take<Cursor<Vec<u8>>>, BufReader<StreamReader>>;
/// Instead, it must be used as an opaque `Read` or `BufRead` structure. /// Instead, it must be used as an opaque `Read` or `BufRead` structure.
pub struct DataStream { pub struct DataStream {
stream: InnerStream, stream: InnerStream,
network: HttpStream, network: HyperNetStream,
} }
impl DataStream { impl DataStream {
pub(crate) fn new(stream: InnerStream, network: HttpStream) -> DataStream { #[inline(always)]
DataStream { stream: stream, network: network, } pub(crate) fn new(stream: InnerStream, network: HyperNetStream) -> DataStream {
DataStream { stream, network }
} }
} }
impl Read for DataStream { impl Read for DataStream {
#[inline(always)]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.stream.read(buf) self.stream.read(buf)
} }
} }
impl BufRead for DataStream { impl BufRead for DataStream {
#[inline(always)]
fn fill_buf(&mut self) -> io::Result<&[u8]> { fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.stream.fill_buf() self.stream.fill_buf()
} }
#[inline(always)]
fn consume(&mut self, amt: usize) { fn consume(&mut self, amt: usize) {
self.stream.consume(amt) self.stream.consume(amt)
} }
} }
// pub fn kill_stream<S: Read>(stream: &mut S, network: &mut HyperNetStream) {
pub fn kill_stream<S: Read, N: NetworkStream>(stream: &mut S, network: &mut N) { pub fn kill_stream<S: Read, N: NetworkStream>(stream: &mut S, network: &mut N) {
io::copy(&mut stream.take(1024), &mut io::sink()).expect("sink"); io::copy(&mut stream.take(1024), &mut io::sink()).expect("kill_stream: sink");
// If there are any more bytes, kill it. // If there are any more bytes, kill it.
let mut buf = [0]; let mut buf = [0];
@ -61,4 +134,3 @@ impl Drop for DataStream {
kill_stream(&mut self.stream, &mut self.network); kill_stream(&mut self.stream, &mut self.network);
} }
} }

View File

@ -111,6 +111,7 @@ impl LaunchError {
} }
impl From<hyper::Error> for LaunchError { impl From<hyper::Error> for LaunchError {
#[inline]
fn from(error: hyper::Error) -> LaunchError { fn from(error: hyper::Error) -> LaunchError {
match error { match error {
hyper::Error::Io(e) => LaunchError::new(LaunchErrorKind::Io(e)), hyper::Error::Io(e) => LaunchError::new(LaunchErrorKind::Io(e)),
@ -120,6 +121,7 @@ impl From<hyper::Error> for LaunchError {
} }
impl fmt::Display for LaunchErrorKind { impl fmt::Display for LaunchErrorKind {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
LaunchErrorKind::Io(ref e) => write!(f, "I/O error: {}", e), LaunchErrorKind::Io(ref e) => write!(f, "I/O error: {}", e),
@ -129,6 +131,7 @@ impl fmt::Display for LaunchErrorKind {
} }
impl fmt::Debug for LaunchError { impl fmt::Debug for LaunchError {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.mark_handled(); self.mark_handled();
write!(f, "{:?}", self.kind()) write!(f, "{:?}", self.kind())
@ -136,6 +139,7 @@ impl fmt::Debug for LaunchError {
} }
impl fmt::Display for LaunchError { impl fmt::Display for LaunchError {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.mark_handled(); self.mark_handled();
write!(f, "{}", self.kind()) write!(f, "{}", self.kind())
@ -143,6 +147,7 @@ impl fmt::Display for LaunchError {
} }
impl ::std::error::Error for LaunchError { impl ::std::error::Error for LaunchError {
#[inline]
fn description(&self) -> &str { fn description(&self) -> &str {
self.mark_handled(); self.mark_handled();
match *self.kind() { match *self.kind() {

View File

@ -7,6 +7,7 @@
#![feature(lookup_host)] #![feature(lookup_host)]
#![feature(plugin)] #![feature(plugin)]
#![feature(never_type)] #![feature(never_type)]
#![feature(concat_idents)]
#![plugin(pear_codegen)] #![plugin(pear_codegen)]
@ -99,6 +100,8 @@
#[macro_use] extern crate log; #[macro_use] extern crate log;
#[macro_use] extern crate pear; #[macro_use] extern crate pear;
#[cfg(feature = "tls")] extern crate rustls;
#[cfg(feature = "tls")] extern crate hyper_rustls;
extern crate term_painter; extern crate term_painter;
extern crate hyper; extern crate hyper;
extern crate url; extern crate url;

View File

@ -32,13 +32,14 @@ impl LoggingLevel {
} }
impl FromStr for LoggingLevel { impl FromStr for LoggingLevel {
type Err = (); type Err = &'static str;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
let level = match s { let level = match s {
"critical" => LoggingLevel::Critical, "critical" => LoggingLevel::Critical,
"normal" => LoggingLevel::Normal, "normal" => LoggingLevel::Normal,
"debug" => LoggingLevel::Debug, "debug" => LoggingLevel::Debug,
_ => return Err(()) _ => return Err("a log level (debug, normal, critical)")
}; };
Ok(level) Ok(level)
@ -87,9 +88,10 @@ impl Log for RocketLogger {
return; return;
} }
// Don't print Hyper's messages unless Debug is enabled. // Don't print Hyper or Rustls messages unless debug is enabled.
let from_hyper = record.location().module_path().starts_with("hyper::"); let from_hyper = record.location().module_path().starts_with("hyper::");
if from_hyper && self.0 != LoggingLevel::Debug { let from_rustls = record.location().module_path().starts_with("rustls::");
if self.0 != LoggingLevel::Debug && (from_hyper || from_rustls) {
return; return;
} }

View File

@ -6,9 +6,9 @@ use std::io::{self, Write};
use term_painter::Color::*; use term_painter::Color::*;
use term_painter::ToStyle; use term_painter::ToStyle;
use state::Container; use state::Container;
#[cfg(feature = "tls")] use hyper_rustls::TlsServer;
use {logger, handler}; use {logger, handler};
use ext::ReadExt; use ext::ReadExt;
use config::{self, Config}; use config::{self, Config};
@ -74,6 +74,35 @@ impl hyper::Handler for Rocket {
} }
} }
// This macro is a terrible hack to get around Hyper's Server<L> type. What we
// want is to use almost exactly the same launch code when we're serving over
// HTTPS as over HTTP. But Hyper forces two different types, so we can't use the
// same code, at least not trivially. These macros get around that by passing in
// the same code as a continuation in `$continue`. This wouldn't work as a
// regular function taking in a closure because the types of the inputs to the
// closure would be different depending on whether TLS was enabled or not.
#[cfg(not(feature = "tls"))]
macro_rules! serve {
($rocket:expr, $addr:expr, |$server:ident, $proto:ident| $continue:expr) => ({
let ($proto, $server) = ("http://", hyper::Server::http($addr));
$continue
})
}
#[cfg(feature = "tls")]
macro_rules! serve {
($rocket:expr, $addr:expr, |$server:ident, $proto:ident| $continue:expr) => ({
if let Some(ref tls) = $rocket.config.tls {
let tls = TlsServer::new(tls.certs.clone(), tls.key.clone());
let ($proto, $server) = ("https://", hyper::Server::https($addr, tls));
$continue
} else {
let ($proto, $server) = ("http://", hyper::Server::http($addr));
$continue
}
})
}
impl Rocket { impl Rocket {
#[inline] #[inline]
fn issue_response(&self, mut response: Response, hyp_res: hyper::FreshResponse) { fn issue_response(&self, mut response: Response, hyp_res: hyper::FreshResponse) {
@ -262,7 +291,11 @@ impl Rocket {
Outcome::Forward(data) Outcome::Forward(data)
} }
// TODO: DOC. // Finds the error catcher for the status `status` and executes it fo the
// given request `req`. If a user has registere a catcher for `status`, the
// catcher is called. If the catcher fails to return a good response, the
// 500 catcher is executed. if there is no registered catcher for `status`,
// the default catcher is used.
fn handle_error<'r>(&self, status: Status, req: &'r Request) -> Response<'r> { fn handle_error<'r>(&self, status: Status, req: &'r Request) -> Response<'r> {
warn_!("Responding with {} catcher.", Red.paint(&status)); warn_!("Responding with {} catcher.", Red.paint(&status));
@ -350,6 +383,16 @@ impl Rocket {
info_!("workers: {}", White.paint(config.workers)); info_!("workers: {}", White.paint(config.workers));
info_!("session key: {}", White.paint(config.session_key.kind())); info_!("session key: {}", White.paint(config.session_key.kind()));
let tls_configured = config.tls.is_some();
if tls_configured && cfg!(feature = "tls") {
info_!("tls: {}", White.paint("enabled"));
} else {
info_!("tls: {}", White.paint("disabled"));
if tls_configured {
warn_!("tls is configured, but the tls feature is disabled");
}
}
for (name, value) in config.extras() { for (name, value) in config.extras() {
info_!("{} {}: {}", Yellow.paint("[extra]"), name, White.paint(value)); info_!("{} {}: {}", Yellow.paint("[extra]"), name, White.paint(value));
} }
@ -553,14 +596,15 @@ impl Rocket {
} }
let full_addr = format!("{}:{}", self.config.address, self.config.port); let full_addr = format!("{}:{}", self.config.address, self.config.port);
let server = match hyper::Server::http(full_addr.as_str()) { serve!(self, &full_addr, |server, proto| {
Ok(hyper_server) => hyper_server, let server = match server {
Ok(server) => server,
Err(e) => return LaunchError::from(e) Err(e) => return LaunchError::from(e)
}; };
info!("🚀 {} {}{}", info!("🚀 {} {}{}",
White.paint("Rocket has launched from"), White.paint("Rocket has launched from"),
White.bold().paint("http://"), White.bold().paint(proto),
White.bold().paint(&full_addr)); White.bold().paint(&full_addr));
let threads = self.config.workers as usize; let threads = self.config.workers as usize;
@ -569,6 +613,7 @@ impl Rocket {
} }
unreachable!("the call to `handle_threads` should block on success") unreachable!("the call to `handle_threads` should block on success")
})
} }
/// Retrieves all of the mounted routes. /// Retrieves all of the mounted routes.