Use 'spawn_blocking' to drop sync database pools.

This was already done for the connections, but pools might also do
synchronous/blocking work on Drop.

Fixes #1466.
This commit is contained in:
Jeb Rosen 2020-11-04 20:05:31 -08:00 committed by Sergio Benitez
parent 2f98299272
commit c6298b9e11
2 changed files with 66 additions and 3 deletions

View File

@ -707,7 +707,8 @@ impl Poolable for memcache::Client {
#[doc(hidden)] #[doc(hidden)]
pub struct ConnectionPool<K, C: Poolable> { pub struct ConnectionPool<K, C: Poolable> {
config: Config, config: Config,
pool: r2d2::Pool<C::Manager>, // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
pool: Option<r2d2::Pool<C::Manager>>,
semaphore: Arc<Semaphore>, semaphore: Arc<Semaphore>,
_marker: PhantomData<fn() -> K>, _marker: PhantomData<fn() -> K>,
} }
@ -766,7 +767,8 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
let pool_size = config.pool_size; let pool_size = config.pool_size;
match C::pool(db, &rocket) { match C::pool(db, &rocket) {
Ok(pool) => Ok(rocket.manage(ConnectionPool::<K, C> { Ok(pool) => Ok(rocket.manage(ConnectionPool::<K, C> {
pool, config, config,
pool: Some(pool),
semaphore: Arc::new(Semaphore::new(pool_size as usize)), semaphore: Arc::new(Semaphore::new(pool_size as usize)),
_marker: PhantomData, _marker: PhantomData,
})), })),
@ -787,7 +789,9 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
} }
}; };
let pool = self.pool.clone(); let pool = self.pool.as_ref().cloned()
.expect("internal invariant broken: self.pool is Some");
match run_blocking(move || pool.get_timeout(duration)).await { match run_blocking(move || pool.get_timeout(duration)).await {
Ok(c) => Ok(Connection { Ok(c) => Ok(Connection {
connection: Arc::new(Mutex::new(Some(c))), connection: Arc::new(Mutex::new(Some(c))),
@ -849,6 +853,13 @@ impl<K, C: Poolable> Drop for Connection<K, C> {
} }
} }
impl<K, C: Poolable> Drop for ConnectionPool<K, C> {
fn drop(&mut self) {
let pool = self.pool.take();
tokio::task::spawn_blocking(move || drop(pool));
}
}
#[rocket::async_trait] #[rocket::async_trait]
impl<'a, 'r, K: 'static, C: Poolable> FromRequest<'a, 'r> for Connection<K, C> { impl<'a, 'r, K: 'static, C: Poolable> FromRequest<'a, 'r> for Connection<K, C> {
type Error = (); type Error = ();

View File

@ -52,3 +52,55 @@ mod rusqlite_integration_test {
}).await; }).await;
} }
} }
#[cfg(feature = "databases")]
#[cfg(test)]
mod drop_runtime_test {
use r2d2::{ManageConnection, Pool};
use rocket_contrib::databases::{database, Poolable, PoolResult};
use tokio::runtime::Runtime;
struct ContainsRuntime(Runtime);
struct TestConnection;
impl ManageConnection for ContainsRuntime {
type Connection = TestConnection;
type Error = std::convert::Infallible;
fn connect(&self) -> Result<Self::Connection, Self::Error> {
Ok(TestConnection)
}
fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
Ok(())
}
fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
false
}
}
impl Poolable for TestConnection {
type Manager = ContainsRuntime;
type Error = ();
fn pool(_db_name: &str, _rocket: &rocket::Rocket) -> PoolResult<Self> {
let manager = ContainsRuntime(tokio::runtime::Runtime::new().unwrap());
Ok(Pool::builder().build(manager)?)
}
}
#[database("test_db")]
struct TestDb(TestConnection);
#[rocket::async_test]
async fn test_drop_runtime() {
use rocket::figment::{Figment, util::map};
let config = Figment::from(rocket::Config::default())
.merge(("databases", map!["test_db" => map!["url" => ""]]));
let rocket = rocket::custom(config).attach(TestDb::fairing());
drop(rocket);
}
}