Set better default 'diesel::SQLite' options.

The options set WAL, a 1s busy timeout, and enables foreign keys.

This also adds a focused 'databases::Config::figment()', used to
retrieve a focused figment for a given config.
This commit is contained in:
Sergio Benitez 2021-04-07 23:06:44 -07:00
parent 5f568599a9
commit cfd5af38fe
4 changed files with 93 additions and 11 deletions

View File

@ -78,12 +78,29 @@ impl Config {
/// # } /// # }
/// ``` /// ```
pub fn from(db_name: &str, rocket: &rocket::Rocket) -> Result<Config, figment::Error> { pub fn from(db_name: &str, rocket: &rocket::Rocket) -> Result<Config, figment::Error> {
Config::figment(db_name, rocket).extract::<Self>()
}
/// Returns a `Figment` focused on the configuration for the database with
/// name `db_name`.
///
/// # Example
///
/// ```rust
/// use rocket::Rocket;
/// use rocket_contrib::databases::Config;
///
/// fn pool(rocket: &Rocket) {
/// let my_db_figment = Config::figment("my_db", rocket);
/// let mysql_prod_figment = Config::figment("mysql_prod", rocket);
/// }
/// ```
pub fn figment(db_name: &str, rocket: &rocket::Rocket) -> Figment {
let db_key = format!("databases.{}", db_name); let db_key = format!("databases.{}", db_name);
let key = |name: &str| format!("{}.{}", db_key, name); let key = |name: &str| format!("{}.{}", db_key, name);
Figment::from(rocket.figment()) Figment::from(rocket.figment())
.merge(Serialized::default(&key("pool_size"), rocket.config().workers * 2)) .join(Serialized::default(&key("pool_size"), rocket.config().workers * 2))
.merge(Serialized::default(&key("timeout"), 5)) .join(Serialized::default(&key("timeout"), 5))
.extract_inner::<Self>(&db_key) .focus(&db_key)
} }
} }

View File

@ -110,9 +110,32 @@ impl Poolable for diesel::SqliteConnection {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn pool(db_name: &str, rocket: &rocket::Rocket) -> PoolResult<Self> { fn pool(db_name: &str, rocket: &rocket::Rocket) -> PoolResult<Self> {
use diesel::{SqliteConnection, connection::SimpleConnection};
use diesel::r2d2::{CustomizeConnection, ConnectionManager, Error, Pool};
#[derive(Debug)]
struct Customizer;
impl CustomizeConnection<SqliteConnection, Error> for Customizer {
fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), Error> {
conn.batch_execute("\
PRAGMA journal_mode = WAL;\
PRAGMA busy_timeout = 1000;\
PRAGMA foreign_keys = ON;\
").map_err(Error::QueryError)?;
Ok(())
}
}
let config = Config::from(db_name, rocket)?; let config = Config::from(db_name, rocket)?;
let manager = diesel::r2d2::ConnectionManager::new(&config.url); let manager = ConnectionManager::new(&config.url);
Ok(r2d2::Pool::builder().max_size(config.pool_size).build(manager)?) let pool = Pool::builder()
.connection_customizer(Box::new(Customizer))
.max_size(config.pool_size)
.build(manager)?;
Ok(pool)
} }
} }
@ -160,8 +183,50 @@ impl Poolable for rusqlite::Connection {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn pool(db_name: &str, rocket: &rocket::Rocket) -> PoolResult<Self> { fn pool(db_name: &str, rocket: &rocket::Rocket) -> PoolResult<Self> {
let config = Config::from(db_name, rocket)?; use rocket::figment::providers::Serialized;
let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url);
#[derive(Debug, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "snake_case")]
enum OpenFlag {
ReadOnly,
ReadWrite,
Create,
Uri,
Memory,
NoMutex,
FullMutex,
SharedCache,
PrivateCache,
Nofollow,
}
let figment = Config::figment(db_name, rocket);
let config: Config = figment.extract()?;
let open_flags: Vec<OpenFlag> = figment
.join(Serialized::default("open_flags", <Vec<OpenFlag>>::new()))
.extract_inner("open_flags")?;
let mut flags = rusqlite::OpenFlags::default();
for flag in open_flags {
let sql_flag = match flag {
OpenFlag::ReadOnly => rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY,
OpenFlag::ReadWrite => rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE,
OpenFlag::Create => rusqlite::OpenFlags::SQLITE_OPEN_CREATE,
OpenFlag::Uri => rusqlite::OpenFlags::SQLITE_OPEN_URI,
OpenFlag::Memory => rusqlite::OpenFlags::SQLITE_OPEN_MEMORY,
OpenFlag::NoMutex => rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX,
OpenFlag::FullMutex => rusqlite::OpenFlags::SQLITE_OPEN_FULL_MUTEX,
OpenFlag::SharedCache => rusqlite::OpenFlags::SQLITE_OPEN_SHARED_CACHE,
OpenFlag::PrivateCache => rusqlite::OpenFlags::SQLITE_OPEN_PRIVATE_CACHE,
OpenFlag::Nofollow => rusqlite::OpenFlags::SQLITE_OPEN_NOFOLLOW,
};
flags.insert(sql_flag)
};
let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url)
.with_flags(flags);
Ok(r2d2::Pool::builder().max_size(config.pool_size).build(manager)?) Ok(r2d2::Pool::builder().max_size(config.pool_size).build(manager)?)
} }
} }

View File

@ -29,7 +29,7 @@ mod rusqlite_integration_test {
use rocket::figment::{Figment, util::map}; use rocket::figment::{Figment, util::map};
let options = map!["url" => ":memory:"]; let options = map!["url" => ":memory:"];
let config = Figment::from(rocket::Config::default()) let config = Figment::from(rocket::Config::debug_default())
.merge(("databases", map!["test_db" => &options])) .merge(("databases", map!["test_db" => &options]))
.merge(("databases", map!["test_db_2" => &options])); .merge(("databases", map!["test_db_2" => &options]));
@ -100,7 +100,7 @@ mod drop_runtime_test {
async fn test_drop_runtime() { async fn test_drop_runtime() {
use rocket::figment::{Figment, util::map}; use rocket::figment::{Figment, util::map};
let config = Figment::from(rocket::Config::default()) let config = Figment::from(rocket::Config::debug_default())
.merge(("databases", map!["test_db" => map!["url" => ""]])); .merge(("databases", map!["test_db" => map!["url" => ""]]));
let rocket = rocket::custom(config).attach(TestDb::fairing()); let rocket = rocket::custom(config).attach(TestDb::fairing());

View File

@ -37,7 +37,7 @@ atomic = "0.5"
parking_lot = "0.11" parking_lot = "0.11"
ubyte = {version = "0.10", features = ["serde"] } ubyte = {version = "0.10", features = ["serde"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
figment = { version = "0.10.2", features = ["toml", "env"] } figment = { version = "0.10.4", features = ["toml", "env"] }
rand = "0.8" rand = "0.8"
either = "1" either = "1"
pin-project-lite = "0.2" pin-project-lite = "0.2"