diff --git a/contrib/codegen/src/database.rs b/contrib/codegen/src/database.rs index cb906f94..12bb068e 100644 --- a/contrib/codegen/src/database.rs +++ b/contrib/codegen/src/database.rs @@ -97,18 +97,13 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result(self, f: F) -> R + pub async fn run(&self, f: F) -> R where F: FnOnce(&mut #conn_type) -> R + Send + 'static, R: Send + 'static, { self.0.run(f).await } - - /// Asynchronously acquires another connection from the connection pool. - pub async fn clone(&mut self) -> ::std::result::Result { - self.0.clone().await.map(Self) - } } #[::rocket::async_trait] diff --git a/contrib/codegen/tests/ui-fail-nightly/database-types.stderr b/contrib/codegen/tests/ui-fail-nightly/database-types.stderr index ce769c9f..95e996f9 100644 --- a/contrib/codegen/tests/ui-fail-nightly/database-types.stderr +++ b/contrib/codegen/tests/ui-fail-nightly/database-types.stderr @@ -1,21 +1,21 @@ -error[E0277]: the trait bound `Unknown: rocket_contrib::databases::Poolable` is not satisfied +error[E0277]: the trait bound `Unknown: Poolable` is not satisfied --> $DIR/database-types.rs:7:10 | 7 | struct A(Unknown); - | ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown` + | ^^^^^^^ the trait `Poolable` is not implemented for `Unknown` | - ::: $WORKSPACE/contrib/lib/src/databases.rs:832:29 + ::: $WORKSPACE/contrib/lib/src/databases.rs | -832 | pub struct Connection { + | pub struct Connection { | -------- required by this bound in `rocket_contrib::databases::Connection` -error[E0277]: the trait bound `std::vec::Vec: rocket_contrib::databases::Poolable` is not satisfied +error[E0277]: the trait bound `Vec: Poolable` is not satisfied --> $DIR/database-types.rs:10:10 | 10 | struct B(Vec); - | ^^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec` + | ^^^^^^^^ the trait `Poolable` is not implemented for `Vec` | - ::: $WORKSPACE/contrib/lib/src/databases.rs:832:29 + ::: $WORKSPACE/contrib/lib/src/databases.rs | -832 | pub struct Connection { + | pub struct Connection { | -------- required by this bound in `rocket_contrib::databases::Connection` diff --git a/contrib/lib/src/databases.rs b/contrib/lib/src/databases.rs index 039c4da6..2a71e6cc 100644 --- a/contrib/lib/src/databases.rs +++ b/contrib/lib/src/databases.rs @@ -191,6 +191,23 @@ //! database. This corresponds to the database name set as the database's //! configuration key. //! +//! The macro generates a [`FromRequest`] implementation for the decorated type, +//! allowing the type to be used as a request guard. This implementation +//! retrieves a connection from the database pool or fails with a +//! `Status::ServiceUnavailable` if connecting to the database times out. +//! +//! The macro will also generate two inherent methods on the decorated type: +//! +//! * `fn fairing() -> impl Fairing` +//! +//! Returns a fairing that initializes the associated database connection +//! pool. +//! +//! * `async fn get_one(&Cargo) -> Option` +//! +//! Retrieves a connection wrapper from the configured pool. Returns `Some` +//! as long as `Self::fairing()` has been attached. +//! //! The attribute can only be applied to unit-like structs with one type. The //! internal type of the structure must implement [`Poolable`]. //! @@ -221,23 +238,6 @@ //! # } //! ``` //! -//! The macro generates a [`FromRequest`] implementation for the decorated type, -//! allowing the type to be used as a request guard. This implementation -//! retrieves a connection from the database pool or fails with a -//! `Status::ServiceUnavailable` if connecting to the database times out. -//! -//! The macro will also generate two inherent methods on the decorated type: -//! -//! * `fn fairing() -> impl Fairing` -//! -//! Returns a fairing that initializes the associated database connection -//! pool. -//! -//! * `async fn get_one(&Cargo) -> Option` -//! -//! Retrieves a connection wrapper from the configured pool. Returns `Some` -//! as long as `Self::fairing()` has been attached. -//! //! The fairing returned from the generated `fairing()` method _must_ be //! attached for the request guard implementation to succeed. Putting the pieces //! together, a use of the `#[database]` attribute looks as follows: @@ -322,37 +322,6 @@ //! # } //! ``` //! -//! `run()` takes the connection by value. To make multiple calls to run, -//! obtain a second connection first with `clone()`: -//! -//! ```rust -//! # #[macro_use] extern crate rocket; -//! # #[macro_use] extern crate rocket_contrib; -//! # -//! # #[cfg(feature = "diesel_sqlite_pool")] -//! # mod test { -//! # use rocket_contrib::databases::diesel; -//! # type Data = (); -//! #[database("my_db")] -//! struct MyDatabase(diesel::SqliteConnection); -//! -//! fn load_from_db(conn: &diesel::SqliteConnection) -> Data { -//! // Do something with connection, return some data. -//! # () -//! } -//! -//! #[get("/")] -//! async fn my_handler(mut conn: MyDatabase) -> Data { -//! let cloned = conn.clone().await.expect(""); -//! cloned.run(|c| load_from_db(c)).await; -//! -//! // Do something else -//! conn.run(|c| load_from_db(c)).await; -//! } -//! # } -//! ``` -//! -//! //! # Database Support //! //! Built-in support is provided for many popular databases and drivers. Support @@ -418,10 +387,11 @@ use std::sync::Arc; use rocket::config::{self, Value}; use rocket::fairing::{AdHoc, Fairing}; +use rocket::request::{Request, Outcome, FromRequest}; +use rocket::outcome::IntoOutcome; use rocket::http::Status; -use rocket::request::Outcome; -use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore, Mutex}; use self::r2d2::ManageConnection; @@ -827,29 +797,25 @@ pub struct ConnectionPool { /// types are properly checked. #[doc(hidden)] pub struct Connection { - pool: ConnectionPool, - connection: Option>, - _permit: Option, + connection: Arc>>>, + permit: Option, _marker: PhantomData K>, } -// A wrapper around spawn_blocking that propagates panics to the calling code +// A wrapper around spawn_blocking that propagates panics to the calling code. async fn run_blocking(job: F) -> R -where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, + where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { match tokio::task::spawn_blocking(job).await { Ok(ret) => ret, Err(e) => match e.try_into_panic() { Ok(panic) => std::panic::resume_unwind(panic), - Err(_) => unreachable!("spawn_blocking tasks are never canceled"), + Err(_) => unreachable!("spawn_blocking tasks are never cancelled"), } } } impl ConnectionPool { - #[inline] pub fn fairing(fairing_name: &'static str, config_name: &'static str) -> impl Fairing { AdHoc::on_attach(fairing_name, move |mut rocket| async move { let config = database_config(config_name, rocket.config().await); @@ -881,9 +847,9 @@ impl ConnectionPool { } async fn get(&self) -> Result, ()> { - // TODO: timeout duration + // TODO: Make timeout configurable. let permit = match tokio::time::timeout( - self.pool.connection_timeout(), + std::time::Duration::from_secs(5), self.semaphore.clone().acquire_owned() ).await { Ok(p) => p, @@ -893,18 +859,14 @@ impl ConnectionPool { } }; + // TODO: Make timeout configurable. let pool = self.pool.clone(); - - // TODO: timeout duration - match run_blocking(move || pool.get_timeout(std::time::Duration::from_secs(0))).await { - Ok(c) => { - Ok(Connection { - pool: self.clone(), - connection: Some(c), - _permit: Some(permit), - _marker: PhantomData, - }) - } + match run_blocking(move || pool.get_timeout(std::time::Duration::from_secs(5))).await { + Ok(c) => Ok(Connection { + connection: Arc::new(Mutex::new(Some(c))), + permit: Some(permit), + _marker: PhantomData, + }), Err(e) => { error_!("Failed to get a database connection: {}", e); Err(()) @@ -924,64 +886,49 @@ impl ConnectionPool { } } -impl Clone for ConnectionPool { - fn clone(&self) -> Self { - Self { - pool: self.pool.clone(), - semaphore: self.semaphore.clone(), - _marker: PhantomData, - } - } -} - impl Connection { #[inline] - pub async fn run(self, f: F) -> R + pub async fn run(&self, f: F) -> R where F: FnOnce(&mut C) -> R + Send + 'static, R: Send + 'static, { + let mut connection = self.connection.clone().lock_owned().await; run_blocking(move || { - // 'self' contains a semaphore permit that should be held throughout - // this entire spawned task. Explicitly move it, so that a - // refactoring won't accidentally release the permit too early. - let mut this: Self = self; - - let mut conn = this.connection.take() + let conn = connection.as_mut() .expect("internal invariant broken: self.connection is Some"); - - f(&mut conn) + f(conn) }).await } - - #[inline] - pub async fn clone(&mut self) -> Result { - self.pool.get().await - } } impl Drop for Connection { fn drop(&mut self) { - if let Some(conn) = self.connection.take() { - tokio::task::spawn_blocking(|| drop(conn)); - } + let connection = self.connection.clone(); + let permit = self.permit.take(); + tokio::spawn(async move { + let mut connection = connection.lock_owned().await; + tokio::task::spawn_blocking(move || { + if let Some(conn) = connection.take() { + drop(conn); + } + // NB: Explicitly dropping the permit here so that it's only + // released after the connection is. + drop(permit); + }) + }); } } #[rocket::async_trait] -impl<'a, 'r, K: 'static, C: Poolable> rocket::request::FromRequest<'a, 'r> for Connection { +impl<'a, 'r, K: 'static, C: Poolable> FromRequest<'a, 'r> for Connection { type Error = (); #[inline] - async fn from_request(request: &'a rocket::request::Request<'r>) -> Outcome { + async fn from_request(request: &'a Request<'r>) -> Outcome { match request.managed_state::>() { - Some(inner) => { - match inner.get().await { - Ok(c) => Outcome::Success(c), - Err(()) => Outcome::Failure((Status::ServiceUnavailable, ())), - } - } + Some(c) => c.get().await.into_outcome(Status::ServiceUnavailable), None => { - error_!("Database fairing was not attached for {}", std::any::type_name::()); + error_!("Missing database fairing for `{}`", std::any::type_name::()); Outcome::Failure((Status::InternalServerError, ())) } } diff --git a/examples/todo/src/main.rs b/examples/todo/src/main.rs index a37f51e6..36b9b5b3 100644 --- a/examples/todo/src/main.rs +++ b/examples/todo/src/main.rs @@ -7,6 +7,8 @@ mod task; #[cfg(test)] mod tests; +use std::fmt::Display; + use rocket::Rocket; use rocket::fairing::AdHoc; use rocket::request::{Form, FlashMessage}; @@ -31,20 +33,20 @@ struct Context { } impl Context { - pub async fn err(conn: DbConn, msg: String) -> Context { + pub async fn err(conn: &DbConn, msg: M) -> Context { Context { - msg: Some(("error".to_string(), msg)), + msg: Some(("error".into(), msg.to_string())), tasks: Task::all(conn).await.unwrap_or_default() } } - pub async fn raw(conn: DbConn, msg: Option<(String, String)>) -> Context { + pub async fn raw(conn: &DbConn, msg: Option<(String, String)>) -> Context { match Task::all(conn).await { Ok(tasks) => Context { msg, tasks }, Err(e) => { error_!("DB Task::all() error: {}", e); Context { - msg: Some(("error".to_string(), "Couldn't access the task database.".to_string())), + msg: Some(("error".into(), "Fail to access database.".into())), tasks: vec![] } } @@ -57,7 +59,7 @@ async fn new(todo_form: Form, conn: DbConn) -> Flash { let todo = todo_form.into_inner(); if todo.description.is_empty() { Flash::error(Redirect::to("/"), "Description cannot be empty.") - } else if let Err(e) = Task::insert(todo, conn).await { + } else if let Err(e) = Task::insert(todo, &conn).await { error_!("DB insertion error: {}", e); Flash::error(Redirect::to("/"), "Todo could not be inserted due an internal error.") } else { @@ -66,27 +68,23 @@ async fn new(todo_form: Form, conn: DbConn) -> Flash { } #[put("/")] -async fn toggle(id: i32, mut conn: DbConn) -> Result { - // TODO - let conn2 = conn.clone().await.unwrap(); - match Task::toggle_with_id(id, conn).await { +async fn toggle(id: i32, conn: DbConn) -> Result { + match Task::toggle_with_id(id, &conn).await { Ok(_) => Ok(Redirect::to("/")), Err(e) => { error_!("DB toggle({}) error: {}", id, e); - Err(Template::render("index", Context::err(conn2, "Failed to toggle task.".to_string()).await)) + Err(Template::render("index", Context::err(&conn, "Failed to toggle task.").await)) } } } #[delete("/")] -async fn delete(id: i32, mut conn: DbConn) -> Result, Template> { - // TODO - let conn2 = conn.clone().await.unwrap(); - match Task::delete_with_id(id, conn).await { +async fn delete(id: i32, conn: DbConn) -> Result, Template> { + match Task::delete_with_id(id, &conn).await { Ok(_) => Ok(Flash::success(Redirect::to("/"), "Todo was deleted.")), Err(e) => { error_!("DB deletion({}) error: {}", id, e); - Err(Template::render("index", Context::err(conn2, "Failed to delete task.".to_string()).await)) + Err(Template::render("index", Context::err(&conn, "Failed to delete task.").await)) } } } @@ -94,20 +92,19 @@ async fn delete(id: i32, mut conn: DbConn) -> Result, Template> #[get("/")] async fn index(msg: Option>, conn: DbConn) -> Template { let msg = msg.map(|m| (m.name().to_string(), m.msg().to_string())); - Template::render("index", Context::raw(conn, msg).await) + Template::render("index", Context::raw(&conn, msg).await) } async fn run_db_migrations(mut rocket: Rocket) -> Result { - let conn = DbConn::get_one(rocket.inspect().await).await.expect("database connection"); - conn.run(|c| { - match embedded_migrations::run(c) { + DbConn::get_one(rocket.inspect().await).await + .expect("database connection") + .run(|c| match embedded_migrations::run(c) { Ok(()) => Ok(rocket), Err(e) => { error!("Failed to run database migrations: {:?}", e); Err(rocket) } - } - }).await + }).await } #[launch] diff --git a/examples/todo/src/task.rs b/examples/todo/src/task.rs index 0e9ad8a8..98726819 100644 --- a/examples/todo/src/task.rs +++ b/examples/todo/src/task.rs @@ -29,14 +29,14 @@ pub struct Todo { } impl Task { - pub async fn all(conn: DbConn) -> QueryResult> { + pub async fn all(conn: &DbConn) -> QueryResult> { conn.run(|c| { all_tasks.order(tasks::id.desc()).load::(c) }).await } /// Returns the number of affected rows: 1. - pub async fn insert(todo: Todo, conn: DbConn) -> QueryResult { + pub async fn insert(todo: Todo, conn: &DbConn) -> QueryResult { conn.run(|c| { let t = Task { id: None, description: todo.description, completed: false }; diesel::insert_into(tasks::table).values(&t).execute(c) @@ -44,7 +44,7 @@ impl Task { } /// Returns the number of affected rows: 1. - pub async fn toggle_with_id(id: i32, conn: DbConn) -> QueryResult { + pub async fn toggle_with_id(id: i32, conn: &DbConn) -> QueryResult { conn.run(move |c| { let task = all_tasks.find(id).get_result::(c)?; let new_status = !task.completed; @@ -54,13 +54,13 @@ impl Task { } /// Returns the number of affected rows: 1. - pub async fn delete_with_id(id: i32, conn: DbConn) -> QueryResult { + pub async fn delete_with_id(id: i32, conn: &DbConn) -> QueryResult { conn.run(move |c| diesel::delete(all_tasks.find(id)).execute(c)).await } /// Returns the number of affected rows. #[cfg(test)] - pub async fn delete_all(conn: DbConn) -> QueryResult { + pub async fn delete_all(conn: &DbConn) -> QueryResult { conn.run(|c| diesel::delete(all_tasks).execute(c)).await } } diff --git a/examples/todo/src/tests.rs b/examples/todo/src/tests.rs index 73930b80..a184945b 100644 --- a/examples/todo/src/tests.rs +++ b/examples/todo/src/tests.rs @@ -19,9 +19,8 @@ macro_rules! run_test { let rocket = super::rocket(); let $client = Client::new(rocket).await.expect("Rocket client"); let db = super::DbConn::get_one($client.cargo()).await; - let mut $conn = db.expect("failed to get database connection for testing"); - let delete_conn = $conn.clone().await.expect("failed to get a second database connection for testing"); - Task::delete_all(delete_conn).await.expect("failed to delete all tasks for testing"); + let $conn = db.expect("failed to get database connection for testing"); + Task::delete_all(&$conn).await.expect("failed to delete all tasks for testing"); $block }) @@ -32,7 +31,7 @@ macro_rules! run_test { fn test_insertion_deletion() { run_test!(|client, conn| { // Get the tasks before making changes. - let init_tasks = Task::all(conn.clone().await.unwrap()).await.unwrap(); + let init_tasks = Task::all(&conn).await.unwrap(); // Issue a request to insert a new task. client.post("/todo") @@ -42,7 +41,7 @@ fn test_insertion_deletion() { .await; // Ensure we have one more task in the database. - let new_tasks = Task::all(conn.clone().await.unwrap()).await.unwrap(); + let new_tasks = Task::all(&conn).await.unwrap(); assert_eq!(new_tasks.len(), init_tasks.len() + 1); // Ensure the task is what we expect. @@ -54,7 +53,7 @@ fn test_insertion_deletion() { client.delete(format!("/todo/{}", id)).dispatch().await; // Ensure it's gone. - let final_tasks = Task::all(conn).await.unwrap(); + let final_tasks = Task::all(&conn).await.unwrap(); assert_eq!(final_tasks.len(), init_tasks.len()); if final_tasks.len() > 0 { assert_ne!(final_tasks[0].description, "My first task"); @@ -72,16 +71,16 @@ fn test_toggle() { .dispatch() .await; - let task = Task::all(conn.clone().await.unwrap()).await.unwrap()[0].clone(); + let task = Task::all(&conn).await.unwrap()[0].clone(); assert_eq!(task.completed, false); // Issue a request to toggle the task; ensure it is completed. client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await; - assert_eq!(Task::all(conn.clone().await.unwrap()).await.unwrap()[0].completed, true); + assert_eq!(Task::all(&conn).await.unwrap()[0].completed, true); // Issue a request to toggle the task; ensure it's not completed again. client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await; - assert_eq!(Task::all(conn).await.unwrap()[0].completed, false); + assert_eq!(Task::all(&conn).await.unwrap()[0].completed, false); }) } @@ -91,7 +90,7 @@ fn test_many_insertions() { run_test!(|client, conn| { // Get the number of tasks initially. - let init_num = Task::all(conn.clone().await.unwrap()).await.unwrap().len(); + let init_num = Task::all(&conn).await.unwrap().len(); let mut descs = Vec::new(); for i in 0..ITER { @@ -107,7 +106,7 @@ fn test_many_insertions() { descs.insert(0, desc); // Ensure the task was inserted properly and all other tasks remain. - let tasks = Task::all(conn.clone().await.unwrap()).await.unwrap(); + let tasks = Task::all(&conn).await.unwrap(); assert_eq!(tasks.len(), init_num + i + 1); for j in 0..i {