mirror of https://github.com/rwf2/Rocket.git
(sync_db_pools) postgres tls
This commit is contained in:
parent
b3abc760ae
commit
bc981ea333
|
@ -16,6 +16,7 @@ diesel_postgres_pool = ["diesel/postgres", "diesel/r2d2"]
|
|||
diesel_mysql_pool = ["diesel/mysql", "diesel/r2d2"]
|
||||
sqlite_pool = ["rusqlite", "r2d2_sqlite"]
|
||||
postgres_pool = ["postgres", "r2d2_postgres"]
|
||||
postgres_pool_tls = ["postgres_pool", "dep:postgres-native-tls", "dep:native-tls"]
|
||||
memcache_pool = ["memcache", "r2d2-memcache"]
|
||||
|
||||
[dependencies]
|
||||
|
@ -27,6 +28,8 @@ diesel = { version = "2.0.0", default-features = false, optional = true }
|
|||
|
||||
postgres = { version = "0.19", optional = true }
|
||||
r2d2_postgres = { version = "0.18", optional = true }
|
||||
postgres-native-tls = { version = "0.5", optional = true }
|
||||
native-tls = { version = "0.2", optional = true }
|
||||
|
||||
rusqlite = { version = "0.29.0", optional = true }
|
||||
r2d2_sqlite = { version = "0.22.0", optional = true }
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use rocket::{Rocket, Build};
|
||||
use rocket::figment::{self, Figment, providers::Serialized};
|
||||
|
||||
|
@ -39,6 +41,33 @@ pub struct Config {
|
|||
/// Defaults to `5`.
|
||||
// FIXME: Use `time`.
|
||||
pub timeout: u8,
|
||||
/// TLS configuration.
|
||||
pub tls: Option<TlsConfig>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct TlsConfig {
|
||||
/// Allow TLS connections with invalid certificates.
|
||||
///
|
||||
/// _Default:_ `false`.
|
||||
pub accept_invalid_certs: bool,
|
||||
/// Allow TLS connections with invalid hostnames.
|
||||
///
|
||||
/// _Default:_ `false`.
|
||||
pub accept_invalid_hostnames: bool,
|
||||
/// Sets the name of a file containing SSL certificate authority (CA) certificate(s).
|
||||
/// If the file exists, the server’s certificate will be verified to be signed by one of these authorities.
|
||||
///
|
||||
/// _Default:_ `None`.
|
||||
pub ssl_root_cert: Option<PathBuf>,
|
||||
/// Sets the name of a file containing SSL client certificate.
|
||||
///
|
||||
/// _Default:_ `None`.
|
||||
pub ssl_client_cert: Option<PathBuf>,
|
||||
/// Sets the name of a file containing SSL client key.
|
||||
///
|
||||
/// _Default:_ `None`.
|
||||
pub ssl_client_key: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
|
@ -107,10 +136,18 @@ impl Config {
|
|||
.map(|workers| workers * 4)
|
||||
.ok();
|
||||
|
||||
let figment = Figment::from(rocket.figment())
|
||||
let mut figment = Figment::from(rocket.figment())
|
||||
.focus(&db_key)
|
||||
.join(Serialized::default("timeout", 5));
|
||||
|
||||
if figment.find_value("tls").is_ok() {
|
||||
figment = figment.join(Serialized::default("tls.accept_invalid_certs", false))
|
||||
.join(Serialized::default("tls.accept_invalid_hostnames", false))
|
||||
.join(Serialized::default("tls.ssl_root_cert", None::<PathBuf>))
|
||||
.join(Serialized::default("tls.ssl_client_cert", None::<PathBuf>))
|
||||
.join(Serialized::default("tls.ssl_client_key", None::<PathBuf>));
|
||||
}
|
||||
|
||||
match default_pool_size {
|
||||
Some(pool_size) => figment.join(Serialized::default("pool_size", pool_size)),
|
||||
None => figment
|
||||
|
|
|
@ -88,6 +88,8 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
|
|||
Err(Error::Config(e)) => dberr!("config", db, "{}", e, rocket),
|
||||
Err(Error::Pool(e)) => dberr!("pool init", db, "{}", e, rocket),
|
||||
Err(Error::Custom(e)) => dberr!("pool manager", db, "{:?}", e, rocket),
|
||||
Err(Error::Io(e)) => dberr!("io", db, "{:?}", e, rocket),
|
||||
Err(Error::Tls(e)) => dberr!("tls", db, "{:?}", e, rocket),
|
||||
}
|
||||
}).await
|
||||
})
|
||||
|
|
|
@ -14,6 +14,10 @@ pub enum Error<T> {
|
|||
Pool(r2d2::Error),
|
||||
/// An error occurred while extracting a `figment` configuration.
|
||||
Config(figment::Error),
|
||||
/// An IO error occurred.
|
||||
Io(std::io::Error),
|
||||
/// A TLS error occurred.
|
||||
Tls(Box<dyn std::error::Error>),
|
||||
}
|
||||
|
||||
impl<T> From<figment::Error> for Error<T> {
|
||||
|
@ -27,3 +31,9 @@ impl<T> From<r2d2::Error> for Error<T> {
|
|||
Error::Pool(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<std::io::Error> for Error<T> {
|
||||
fn from(error: std::io::Error) -> Self {
|
||||
Error::Io(error)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -184,16 +184,231 @@ impl Poolable for diesel::MysqlConnection {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`.
|
||||
#[cfg(feature = "postgres_pool")]
|
||||
pub mod pg {
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::io;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum MaybeTlsConnector {
|
||||
NoTls(postgres::tls::NoTls),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
Tls(postgres_native_tls::MakeTlsConnector)
|
||||
}
|
||||
|
||||
impl postgres::tls::MakeTlsConnect<postgres::Socket> for MaybeTlsConnector {
|
||||
type Stream = MaybeTlsConnector_Stream;
|
||||
type TlsConnect = MaybeTlsConnector_TlsConnect;
|
||||
type Error = MaybeTlsConnector_Error;
|
||||
|
||||
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
|
||||
match self {
|
||||
MaybeTlsConnector::NoTls(connector) => {
|
||||
<postgres::tls::NoTls as postgres::tls::MakeTlsConnect<postgres::Socket>>::make_tls_connect(connector, domain)
|
||||
.map(Self::TlsConnect::NoTls)
|
||||
.map_err(Self::Error::NoTls)
|
||||
},
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
MaybeTlsConnector::Tls(connector) => {
|
||||
<postgres_native_tls::MakeTlsConnector as postgres::tls::MakeTlsConnect<postgres::Socket>>::make_tls_connect(connector, domain)
|
||||
.map(Self::TlsConnect::Tls)
|
||||
.map_err(Self::Error::Tls)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Stream ---
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
pub enum MaybeTlsConnector_Stream {
|
||||
NoTls(postgres::tls::NoTlsStream),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
Tls(postgres_native_tls::TlsStream<postgres::Socket>)
|
||||
}
|
||||
|
||||
impl postgres::tls::TlsStream for MaybeTlsConnector_Stream {
|
||||
fn channel_binding(&self) -> postgres::tls::ChannelBinding {
|
||||
match self {
|
||||
MaybeTlsConnector_Stream::NoTls(stream) => stream.channel_binding(),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
MaybeTlsConnector_Stream::Tls(stream) => stream.channel_binding(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncRead for MaybeTlsConnector_Stream {
|
||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<Result<(), io::Error>> {
|
||||
match *self {
|
||||
MaybeTlsConnector_Stream::NoTls(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
MaybeTlsConnector_Stream::Tls(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncWrite for MaybeTlsConnector_Stream {
|
||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
|
||||
match *self {
|
||||
MaybeTlsConnector_Stream::NoTls(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
MaybeTlsConnector_Stream::Tls(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
match *self {
|
||||
MaybeTlsConnector_Stream::NoTls(ref mut stream) => Pin::new(stream).poll_flush(cx),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
MaybeTlsConnector_Stream::Tls(ref mut stream) => Pin::new(stream).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
match *self {
|
||||
MaybeTlsConnector_Stream::NoTls(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
MaybeTlsConnector_Stream::Tls(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- TlsConnect ---
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
pub enum MaybeTlsConnector_TlsConnect {
|
||||
NoTls(postgres::tls::NoTls),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
Tls(postgres_native_tls::TlsConnector)
|
||||
}
|
||||
|
||||
impl postgres::tls::TlsConnect<postgres::Socket> for MaybeTlsConnector_TlsConnect {
|
||||
type Stream = MaybeTlsConnector_Stream;
|
||||
type Error = MaybeTlsConnector_Error;
|
||||
type Future = MaybeTlsConnector_Future;
|
||||
|
||||
fn connect(self, socket: postgres::Socket) -> Self::Future {
|
||||
match self {
|
||||
MaybeTlsConnector_TlsConnect::NoTls(connector) => Self::Future::NoTls(connector.connect(socket)),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
MaybeTlsConnector_TlsConnect::Tls(connector) => Self::Future::Tls(connector.connect(socket)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Error ---
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[derive(Debug)]
|
||||
pub enum MaybeTlsConnector_Error {
|
||||
NoTls(postgres::tls::NoTlsError),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
Tls(native_tls::Error)
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MaybeTlsConnector_Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MaybeTlsConnector_Error::NoTls(e) => e.fmt(f),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
MaybeTlsConnector_Error::Tls(e) => e.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for MaybeTlsConnector_Error {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
MaybeTlsConnector_Error::NoTls(e) => e.source(),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
MaybeTlsConnector_Error::Tls(e) => e.source(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Future ---
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
pub enum MaybeTlsConnector_Future {
|
||||
NoTls(postgres::tls::NoTlsFuture),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
Tls(<postgres_native_tls::TlsConnector as postgres::tls::TlsConnect<postgres::Socket>>::Future)
|
||||
}
|
||||
|
||||
impl std::future::Future for MaybeTlsConnector_Future {
|
||||
type Output = Result<MaybeTlsConnector_Stream, MaybeTlsConnector_Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match *self {
|
||||
MaybeTlsConnector_Future::NoTls(ref mut future) => Pin::new(future).poll(cx).map(|v| v.map(MaybeTlsConnector_Stream::NoTls)).map_err(MaybeTlsConnector_Error::NoTls),
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
MaybeTlsConnector_Future::Tls(ref mut future) => Pin::new(future).poll(cx).map(|v| v.map(MaybeTlsConnector_Stream::Tls)).map_err(MaybeTlsConnector_Error::Tls),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "postgres_pool")]
|
||||
impl Poolable for postgres::Client {
|
||||
type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
|
||||
type Manager = r2d2_postgres::PostgresConnectionManager<pg::MaybeTlsConnector>;
|
||||
type Error = postgres::Error;
|
||||
|
||||
fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
|
||||
let config = Config::from(db_name, rocket)?;
|
||||
let url = config.url.parse().map_err(Error::Custom)?;
|
||||
let manager = r2d2_postgres::PostgresConnectionManager::new(url, postgres::tls::NoTls);
|
||||
|
||||
let tls_connector = match config.tls {
|
||||
#[allow(unused_variables)] // `tls_config` is unused when `postgres_pool_tls` is disabled.
|
||||
Some(ref tls_config) => {
|
||||
|
||||
#[cfg(feature = "postgres_pool_tls")]
|
||||
{
|
||||
let mut connector_builder = native_tls::TlsConnector::builder();
|
||||
if let Some(ref cert) = tls_config.ssl_root_cert {
|
||||
let cert_file_bytes = std::fs::read(cert)?;
|
||||
let cert = native_tls::Certificate::from_pem(&cert_file_bytes).map_err(|e| Error::Tls(e.into()))?;
|
||||
connector_builder.add_root_certificate(cert);
|
||||
|
||||
// Client certs
|
||||
match (
|
||||
tls_config.ssl_client_cert.as_ref(),
|
||||
tls_config.ssl_client_key.as_ref(),
|
||||
) {
|
||||
(Some(cert), Some(key)) => {
|
||||
let cert_file_bytes = std::fs::read(cert)?;
|
||||
let key_file_bytes = std::fs::read(key)?;
|
||||
let cert = native_tls::Identity::from_pkcs8(&cert_file_bytes, &key_file_bytes).map_err(|e| Error::Tls(e.into()))?;
|
||||
connector_builder.identity(cert);
|
||||
},
|
||||
(Some(_), None) => {
|
||||
return Err(Error::Tls("Client certificate provided without client key".into()))
|
||||
},
|
||||
(None, Some(_)) => {
|
||||
return Err(Error::Tls("Client key provided without client certificate".into()))
|
||||
},
|
||||
(None, None) => {},
|
||||
}
|
||||
}
|
||||
|
||||
connector_builder.danger_accept_invalid_certs(tls_config.accept_invalid_certs);
|
||||
connector_builder.danger_accept_invalid_hostnames(tls_config.accept_invalid_hostnames);
|
||||
let connector = connector_builder.build().map_err(|e| Error::Tls(e.into()))?;
|
||||
pg::MaybeTlsConnector::Tls(postgres_native_tls::MakeTlsConnector::new(connector))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "postgres_pool_tls"))]
|
||||
{
|
||||
rocket::warn!("TLS is not enabled for the `postgres_pool` feature. Postgres TLS configuration will be ignored. Enable the `postgres_pool_tls` feature to enable TLS.");
|
||||
pg::MaybeTlsConnector::NoTls(postgres::tls::NoTls)
|
||||
}
|
||||
},
|
||||
None => {
|
||||
pg::MaybeTlsConnector::NoTls(postgres::tls::NoTls)
|
||||
}
|
||||
};
|
||||
|
||||
let manager = r2d2_postgres::PostgresConnectionManager::new(url, tls_connector);
|
||||
let pool = r2d2::Pool::builder()
|
||||
.max_size(config.pool_size)
|
||||
.connection_timeout(Duration::from_secs(config.timeout as u64))
|
||||
|
|
Loading…
Reference in New Issue