diff --git a/contrib/lib/src/databases/config.rs b/contrib/lib/src/databases/config.rs index e86699e9..01c00524 100644 --- a/contrib/lib/src/databases/config.rs +++ b/contrib/lib/src/databases/config.rs @@ -78,12 +78,29 @@ impl Config { /// # } /// ``` pub fn from(db_name: &str, rocket: &rocket::Rocket) -> Result { + Config::figment(db_name, rocket).extract::() + } + + /// Returns a `Figment` focused on the configuration for the database with + /// name `db_name`. + /// + /// # Example + /// + /// ```rust + /// use rocket::Rocket; + /// use rocket_contrib::databases::Config; + /// + /// fn pool(rocket: &Rocket) { + /// let my_db_figment = Config::figment("my_db", rocket); + /// let mysql_prod_figment = Config::figment("mysql_prod", rocket); + /// } + /// ``` + pub fn figment(db_name: &str, rocket: &rocket::Rocket) -> Figment { let db_key = format!("databases.{}", db_name); let key = |name: &str| format!("{}.{}", db_key, name); Figment::from(rocket.figment()) - .merge(Serialized::default(&key("pool_size"), rocket.config().workers * 2)) - .merge(Serialized::default(&key("timeout"), 5)) - .extract_inner::(&db_key) + .join(Serialized::default(&key("pool_size"), rocket.config().workers * 2)) + .join(Serialized::default(&key("timeout"), 5)) + .focus(&db_key) } } - diff --git a/contrib/lib/src/databases/poolable.rs b/contrib/lib/src/databases/poolable.rs index 5b78002a..3dcaf3a2 100644 --- a/contrib/lib/src/databases/poolable.rs +++ b/contrib/lib/src/databases/poolable.rs @@ -110,9 +110,32 @@ impl Poolable for diesel::SqliteConnection { type Error = std::convert::Infallible; fn pool(db_name: &str, rocket: &rocket::Rocket) -> PoolResult { + use diesel::{SqliteConnection, connection::SimpleConnection}; + use diesel::r2d2::{CustomizeConnection, ConnectionManager, Error, Pool}; + + #[derive(Debug)] + struct Customizer; + + impl CustomizeConnection for Customizer { + fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), Error> { + conn.batch_execute("\ + PRAGMA journal_mode = WAL;\ + PRAGMA busy_timeout = 1000;\ + PRAGMA foreign_keys = ON;\ + ").map_err(Error::QueryError)?; + + Ok(()) + } + } + let config = Config::from(db_name, rocket)?; - let manager = diesel::r2d2::ConnectionManager::new(&config.url); - Ok(r2d2::Pool::builder().max_size(config.pool_size).build(manager)?) + let manager = ConnectionManager::new(&config.url); + let pool = Pool::builder() + .connection_customizer(Box::new(Customizer)) + .max_size(config.pool_size) + .build(manager)?; + + Ok(pool) } } @@ -160,8 +183,50 @@ impl Poolable for rusqlite::Connection { type Error = std::convert::Infallible; fn pool(db_name: &str, rocket: &rocket::Rocket) -> PoolResult { - let config = Config::from(db_name, rocket)?; - let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url); + use rocket::figment::providers::Serialized; + + #[derive(Debug, serde::Deserialize, serde::Serialize)] + #[serde(rename_all = "snake_case")] + enum OpenFlag { + ReadOnly, + ReadWrite, + Create, + Uri, + Memory, + NoMutex, + FullMutex, + SharedCache, + PrivateCache, + Nofollow, + } + + let figment = Config::figment(db_name, rocket); + let config: Config = figment.extract()?; + let open_flags: Vec = figment + .join(Serialized::default("open_flags", >::new())) + .extract_inner("open_flags")?; + + let mut flags = rusqlite::OpenFlags::default(); + for flag in open_flags { + let sql_flag = match flag { + OpenFlag::ReadOnly => rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY, + OpenFlag::ReadWrite => rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE, + OpenFlag::Create => rusqlite::OpenFlags::SQLITE_OPEN_CREATE, + OpenFlag::Uri => rusqlite::OpenFlags::SQLITE_OPEN_URI, + OpenFlag::Memory => rusqlite::OpenFlags::SQLITE_OPEN_MEMORY, + OpenFlag::NoMutex => rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX, + OpenFlag::FullMutex => rusqlite::OpenFlags::SQLITE_OPEN_FULL_MUTEX, + OpenFlag::SharedCache => rusqlite::OpenFlags::SQLITE_OPEN_SHARED_CACHE, + OpenFlag::PrivateCache => rusqlite::OpenFlags::SQLITE_OPEN_PRIVATE_CACHE, + OpenFlag::Nofollow => rusqlite::OpenFlags::SQLITE_OPEN_NOFOLLOW, + }; + + flags.insert(sql_flag) + }; + + let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url) + .with_flags(flags); + Ok(r2d2::Pool::builder().max_size(config.pool_size).build(manager)?) } } diff --git a/contrib/lib/tests/databases.rs b/contrib/lib/tests/databases.rs index 04e3613a..517a2318 100644 --- a/contrib/lib/tests/databases.rs +++ b/contrib/lib/tests/databases.rs @@ -29,7 +29,7 @@ mod rusqlite_integration_test { use rocket::figment::{Figment, util::map}; let options = map!["url" => ":memory:"]; - let config = Figment::from(rocket::Config::default()) + let config = Figment::from(rocket::Config::debug_default()) .merge(("databases", map!["test_db" => &options])) .merge(("databases", map!["test_db_2" => &options])); @@ -100,7 +100,7 @@ mod drop_runtime_test { async fn test_drop_runtime() { use rocket::figment::{Figment, util::map}; - let config = Figment::from(rocket::Config::default()) + let config = Figment::from(rocket::Config::debug_default()) .merge(("databases", map!["test_db" => map!["url" => ""]])); let rocket = rocket::custom(config).attach(TestDb::fairing()); diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 8d73ae9d..2c7ed2f6 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -37,7 +37,7 @@ atomic = "0.5" parking_lot = "0.11" ubyte = {version = "0.10", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } -figment = { version = "0.10.2", features = ["toml", "env"] } +figment = { version = "0.10.4", features = ["toml", "env"] } rand = "0.8" either = "1" pin-project-lite = "0.2"