(sync_db_pools) postgres tls

This commit is contained in:
Edwin Svensson 2024-01-13 23:27:52 +01:00
parent b3abc760ae
commit bc981ea333
No known key found for this signature in database
GPG Key ID: 7F9EC4DD0C67951F
5 changed files with 271 additions and 4 deletions

View File

@ -16,6 +16,7 @@ diesel_postgres_pool = ["diesel/postgres", "diesel/r2d2"]
diesel_mysql_pool = ["diesel/mysql", "diesel/r2d2"]
sqlite_pool = ["rusqlite", "r2d2_sqlite"]
postgres_pool = ["postgres", "r2d2_postgres"]
postgres_pool_tls = ["postgres_pool", "dep:postgres-native-tls", "dep:native-tls"]
memcache_pool = ["memcache", "r2d2-memcache"]
[dependencies]
@ -27,6 +28,8 @@ diesel = { version = "2.0.0", default-features = false, optional = true }
postgres = { version = "0.19", optional = true }
r2d2_postgres = { version = "0.18", optional = true }
postgres-native-tls = { version = "0.5", optional = true }
native-tls = { version = "0.2", optional = true }
rusqlite = { version = "0.29.0", optional = true }
r2d2_sqlite = { version = "0.22.0", optional = true }

View File

@ -1,3 +1,5 @@
use std::path::PathBuf;
use rocket::{Rocket, Build};
use rocket::figment::{self, Figment, providers::Serialized};
@ -39,6 +41,33 @@ pub struct Config {
/// Defaults to `5`.
// FIXME: Use `time`.
pub timeout: u8,
/// TLS configuration.
pub tls: Option<TlsConfig>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TlsConfig {
/// Allow TLS connections with invalid certificates.
///
/// _Default:_ `false`.
pub accept_invalid_certs: bool,
/// Allow TLS connections with invalid hostnames.
///
/// _Default:_ `false`.
pub accept_invalid_hostnames: bool,
/// Sets the name of a file containing SSL certificate authority (CA) certificate(s).
/// If the file exists, the servers certificate will be verified to be signed by one of these authorities.
///
/// _Default:_ `None`.
pub ssl_root_cert: Option<PathBuf>,
/// Sets the name of a file containing SSL client certificate.
///
/// _Default:_ `None`.
pub ssl_client_cert: Option<PathBuf>,
/// Sets the name of a file containing SSL client key.
///
/// _Default:_ `None`.
pub ssl_client_key: Option<PathBuf>,
}
impl Config {
@ -107,10 +136,18 @@ impl Config {
.map(|workers| workers * 4)
.ok();
let figment = Figment::from(rocket.figment())
let mut figment = Figment::from(rocket.figment())
.focus(&db_key)
.join(Serialized::default("timeout", 5));
if figment.find_value("tls").is_ok() {
figment = figment.join(Serialized::default("tls.accept_invalid_certs", false))
.join(Serialized::default("tls.accept_invalid_hostnames", false))
.join(Serialized::default("tls.ssl_root_cert", None::<PathBuf>))
.join(Serialized::default("tls.ssl_client_cert", None::<PathBuf>))
.join(Serialized::default("tls.ssl_client_key", None::<PathBuf>));
}
match default_pool_size {
Some(pool_size) => figment.join(Serialized::default("pool_size", pool_size)),
None => figment

View File

@ -88,6 +88,8 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
Err(Error::Config(e)) => dberr!("config", db, "{}", e, rocket),
Err(Error::Pool(e)) => dberr!("pool init", db, "{}", e, rocket),
Err(Error::Custom(e)) => dberr!("pool manager", db, "{:?}", e, rocket),
Err(Error::Io(e)) => dberr!("io", db, "{:?}", e, rocket),
Err(Error::Tls(e)) => dberr!("tls", db, "{:?}", e, rocket),
}
}).await
})

View File

@ -14,6 +14,10 @@ pub enum Error<T> {
Pool(r2d2::Error),
/// An error occurred while extracting a `figment` configuration.
Config(figment::Error),
/// An IO error occurred.
Io(std::io::Error),
/// A TLS error occurred.
Tls(Box<dyn std::error::Error>),
}
impl<T> From<figment::Error> for Error<T> {
@ -27,3 +31,9 @@ impl<T> From<r2d2::Error> for Error<T> {
Error::Pool(error)
}
}
impl<T> From<std::io::Error> for Error<T> {
fn from(error: std::io::Error) -> Self {
Error::Io(error)
}
}

View File

@ -184,16 +184,231 @@ impl Poolable for diesel::MysqlConnection {
}
}
// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`.
#[cfg(feature = "postgres_pool")]
pub mod pg {
use std::pin::Pin;
use std::task::{Context, Poll};
use std::io;
#[derive(Clone)]
pub enum MaybeTlsConnector {
NoTls(postgres::tls::NoTls),
#[cfg(feature = "postgres_pool_tls")]
Tls(postgres_native_tls::MakeTlsConnector)
}
impl postgres::tls::MakeTlsConnect<postgres::Socket> for MaybeTlsConnector {
type Stream = MaybeTlsConnector_Stream;
type TlsConnect = MaybeTlsConnector_TlsConnect;
type Error = MaybeTlsConnector_Error;
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
match self {
MaybeTlsConnector::NoTls(connector) => {
<postgres::tls::NoTls as postgres::tls::MakeTlsConnect<postgres::Socket>>::make_tls_connect(connector, domain)
.map(Self::TlsConnect::NoTls)
.map_err(Self::Error::NoTls)
},
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector::Tls(connector) => {
<postgres_native_tls::MakeTlsConnector as postgres::tls::MakeTlsConnect<postgres::Socket>>::make_tls_connect(connector, domain)
.map(Self::TlsConnect::Tls)
.map_err(Self::Error::Tls)
},
}
}
}
// --- Stream ---
#[allow(non_camel_case_types)]
pub enum MaybeTlsConnector_Stream {
NoTls(postgres::tls::NoTlsStream),
#[cfg(feature = "postgres_pool_tls")]
Tls(postgres_native_tls::TlsStream<postgres::Socket>)
}
impl postgres::tls::TlsStream for MaybeTlsConnector_Stream {
fn channel_binding(&self) -> postgres::tls::ChannelBinding {
match self {
MaybeTlsConnector_Stream::NoTls(stream) => stream.channel_binding(),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Stream::Tls(stream) => stream.channel_binding(),
}
}
}
impl tokio::io::AsyncRead for MaybeTlsConnector_Stream {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<Result<(), io::Error>> {
match *self {
MaybeTlsConnector_Stream::NoTls(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Stream::Tls(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl tokio::io::AsyncWrite for MaybeTlsConnector_Stream {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
match *self {
MaybeTlsConnector_Stream::NoTls(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Stream::Tls(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match *self {
MaybeTlsConnector_Stream::NoTls(ref mut stream) => Pin::new(stream).poll_flush(cx),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Stream::Tls(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match *self {
MaybeTlsConnector_Stream::NoTls(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Stream::Tls(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
// --- TlsConnect ---
#[allow(non_camel_case_types)]
pub enum MaybeTlsConnector_TlsConnect {
NoTls(postgres::tls::NoTls),
#[cfg(feature = "postgres_pool_tls")]
Tls(postgres_native_tls::TlsConnector)
}
impl postgres::tls::TlsConnect<postgres::Socket> for MaybeTlsConnector_TlsConnect {
type Stream = MaybeTlsConnector_Stream;
type Error = MaybeTlsConnector_Error;
type Future = MaybeTlsConnector_Future;
fn connect(self, socket: postgres::Socket) -> Self::Future {
match self {
MaybeTlsConnector_TlsConnect::NoTls(connector) => Self::Future::NoTls(connector.connect(socket)),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_TlsConnect::Tls(connector) => Self::Future::Tls(connector.connect(socket)),
}
}
}
// --- Error ---
#[allow(non_camel_case_types)]
#[derive(Debug)]
pub enum MaybeTlsConnector_Error {
NoTls(postgres::tls::NoTlsError),
#[cfg(feature = "postgres_pool_tls")]
Tls(native_tls::Error)
}
impl std::fmt::Display for MaybeTlsConnector_Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MaybeTlsConnector_Error::NoTls(e) => e.fmt(f),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Error::Tls(e) => e.fmt(f),
}
}
}
impl std::error::Error for MaybeTlsConnector_Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
MaybeTlsConnector_Error::NoTls(e) => e.source(),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Error::Tls(e) => e.source(),
}
}
}
// --- Future ---
#[allow(non_camel_case_types)]
pub enum MaybeTlsConnector_Future {
NoTls(postgres::tls::NoTlsFuture),
#[cfg(feature = "postgres_pool_tls")]
Tls(<postgres_native_tls::TlsConnector as postgres::tls::TlsConnect<postgres::Socket>>::Future)
}
impl std::future::Future for MaybeTlsConnector_Future {
type Output = Result<MaybeTlsConnector_Stream, MaybeTlsConnector_Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match *self {
MaybeTlsConnector_Future::NoTls(ref mut future) => Pin::new(future).poll(cx).map(|v| v.map(MaybeTlsConnector_Stream::NoTls)).map_err(MaybeTlsConnector_Error::NoTls),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Future::Tls(ref mut future) => Pin::new(future).poll(cx).map(|v| v.map(MaybeTlsConnector_Stream::Tls)).map_err(MaybeTlsConnector_Error::Tls),
}
}
}
}
#[cfg(feature = "postgres_pool")]
impl Poolable for postgres::Client {
type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
type Manager = r2d2_postgres::PostgresConnectionManager<pg::MaybeTlsConnector>;
type Error = postgres::Error;
fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
let config = Config::from(db_name, rocket)?;
let url = config.url.parse().map_err(Error::Custom)?;
let manager = r2d2_postgres::PostgresConnectionManager::new(url, postgres::tls::NoTls);
let tls_connector = match config.tls {
#[allow(unused_variables)] // `tls_config` is unused when `postgres_pool_tls` is disabled.
Some(ref tls_config) => {
#[cfg(feature = "postgres_pool_tls")]
{
let mut connector_builder = native_tls::TlsConnector::builder();
if let Some(ref cert) = tls_config.ssl_root_cert {
let cert_file_bytes = std::fs::read(cert)?;
let cert = native_tls::Certificate::from_pem(&cert_file_bytes).map_err(|e| Error::Tls(e.into()))?;
connector_builder.add_root_certificate(cert);
// Client certs
match (
tls_config.ssl_client_cert.as_ref(),
tls_config.ssl_client_key.as_ref(),
) {
(Some(cert), Some(key)) => {
let cert_file_bytes = std::fs::read(cert)?;
let key_file_bytes = std::fs::read(key)?;
let cert = native_tls::Identity::from_pkcs8(&cert_file_bytes, &key_file_bytes).map_err(|e| Error::Tls(e.into()))?;
connector_builder.identity(cert);
},
(Some(_), None) => {
return Err(Error::Tls("Client certificate provided without client key".into()))
},
(None, Some(_)) => {
return Err(Error::Tls("Client key provided without client certificate".into()))
},
(None, None) => {},
}
}
connector_builder.danger_accept_invalid_certs(tls_config.accept_invalid_certs);
connector_builder.danger_accept_invalid_hostnames(tls_config.accept_invalid_hostnames);
let connector = connector_builder.build().map_err(|e| Error::Tls(e.into()))?;
pg::MaybeTlsConnector::Tls(postgres_native_tls::MakeTlsConnector::new(connector))
}
#[cfg(not(feature = "postgres_pool_tls"))]
{
rocket::warn!("TLS is not enabled for the `postgres_pool` feature. Postgres TLS configuration will be ignored. Enable the `postgres_pool_tls` feature to enable TLS.");
pg::MaybeTlsConnector::NoTls(postgres::tls::NoTls)
}
},
None => {
pg::MaybeTlsConnector::NoTls(postgres::tls::NoTls)
}
};
let manager = r2d2_postgres::PostgresConnectionManager::new(url, tls_connector);
let pool = r2d2::Pool::builder()
.max_size(config.pool_size)
.connection_timeout(Duration::from_secs(config.timeout as u64))