Support sqlx_sqlite extensions in db_pools.

Resolves #2762.
This commit is contained in:
Sergio Benitez 2024-05-01 20:04:00 -07:00
parent fb39c02212
commit a811a1810d
4 changed files with 26 additions and 5 deletions

View File

@ -15,11 +15,14 @@ use rocket::serde::{Deserialize, Serialize};
/// [default.databases.db_name] /// [default.databases.db_name]
/// url = "/path/to/db.sqlite" /// url = "/path/to/db.sqlite"
/// ///
/// # only `url` is required. `Initializer` provides defaults for the rest. /// # Only `url` is required. These have sane defaults and are optional.
/// min_connections = 64 /// min_connections = 64
/// max_connections = 1024 /// max_connections = 1024
/// connect_timeout = 5 /// connect_timeout = 5
/// idle_timeout = 120 /// idle_timeout = 120
///
/// # This option is only supported by the `sqlx_sqlite` driver.
/// extensions = ["memvfs", "rot13"]
/// ``` /// ```
/// ///
/// Alternatively, a custom provider can be used. For example, a custom `Figment` /// Alternatively, a custom provider can be used. For example, a custom `Figment`
@ -36,6 +39,7 @@ use rocket::serde::{Deserialize, Serialize};
/// max_connections: 1024, /// max_connections: 1024,
/// connect_timeout: 3, /// connect_timeout: 3,
/// idle_timeout: None, /// idle_timeout: None,
/// extensions: None,
/// })); /// }));
/// ///
/// rocket::custom(figment) /// rocket::custom(figment)
@ -45,7 +49,8 @@ use rocket::serde::{Deserialize, Serialize};
/// For general information on configuration in Rocket, see [`rocket::config`]. /// For general information on configuration in Rocket, see [`rocket::config`].
/// For higher-level details on configuring a database, see the [crate-level /// For higher-level details on configuring a database, see the [crate-level
/// docs](crate#configuration). /// docs](crate#configuration).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] // NOTE: Defaults provided by the figment created in the `Initializer` fairing.
#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(crate = "rocket::serde")] #[serde(crate = "rocket::serde")]
pub struct Config { pub struct Config {
/// Database-specific connection and configuration URL. /// Database-specific connection and configuration URL.
@ -80,4 +85,11 @@ pub struct Config {
/// ///
/// _Default:_ `None`. /// _Default:_ `None`.
pub idle_timeout: Option<u64>, pub idle_timeout: Option<u64>,
/// A list of database extensions to load at run-time.
///
/// **Note:** Only the `sqlx_sqlite` driver supports this option (for SQLite
/// extensions) at this time. All other drivers ignore this option.
///
/// _Default:_ `None`.
pub extensions: Option<Vec<String>>,
} }

View File

@ -261,8 +261,8 @@ impl<D: Database> Fairing for Initializer<D> {
let figment = rocket.figment() let figment = rocket.figment()
.focus(&format!("databases.{}", D::NAME)) .focus(&format!("databases.{}", D::NAME))
.merge(Serialized::default("max_connections", workers * 4)) .join(Serialized::default("max_connections", workers * 4))
.merge(Serialized::default("connect_timeout", 5)); .join(Serialized::default("connect_timeout", 5));
match <D::Pool>::init(&figment).await { match <D::Pool>::init(&figment).await {
Ok(pool) => Ok(rocket.manage(D::from(pool))), Ok(pool) => Ok(rocket.manage(D::from(pool))),

View File

@ -180,11 +180,14 @@
//! [default.databases.db_name] //! [default.databases.db_name]
//! url = "db.sqlite" //! url = "db.sqlite"
//! //!
//! # only `url` is required. the rest have defaults and are thus optional //! # Only `url` is required. These have sane defaults and are optional.
//! min_connections = 64 //! min_connections = 64
//! max_connections = 1024 //! max_connections = 1024
//! connect_timeout = 5 //! connect_timeout = 5
//! idle_timeout = 120 //! idle_timeout = 120
//!
//! # This option is only supported by the `sqlx_sqlite` driver.
//! extensions = ["memvfs", "rot13"]
//! ``` //! ```
//! //!
//! Or via environment variables: //! Or via environment variables:

View File

@ -281,6 +281,12 @@ mod sqlx {
*o = std::mem::take(o) *o = std::mem::take(o)
.busy_timeout(Duration::from_secs(__config.connect_timeout)) .busy_timeout(Duration::from_secs(__config.connect_timeout))
.create_if_missing(true); .create_if_missing(true);
if let Some(ref exts) = __config.extensions {
for ext in exts {
*o = std::mem::take(o).extension(ext.clone());
}
}
} }
} }