From bc8c5b9ee217e6acf16f9b6d855fb19ee1867a5d Mon Sep 17 00:00:00 2001 From: Jeb Rosen Date: Sat, 11 Jul 2020 11:17:43 -0700 Subject: [PATCH] Use 'spawn_blocking' in '#[database]'. The connection guard type generated by `#[database]` no longer implements `Deref` and `DerefMut`. Instead, it provides an `async fn run()` that gives access to the underlying connection on a closure run through `spawn_blocking()`. Additionally moves most of the implementation of `#[database]` out of generated code and into library code for better type-checking. --- contrib/codegen/src/database.rs | 79 ++---- .../ui-fail-nightly/database-types.stderr | 30 ++- .../ui-fail-stable/database-types.stderr | 26 +- contrib/lib/src/databases.rs | 248 ++++++++++++++++-- contrib/lib/tests/databases.rs | 41 ++- examples/todo/src/main.rs | 77 +++--- examples/todo/src/task.rs | 36 ++- examples/todo/src/tests.rs | 57 ++-- site/guide/6-state.md | 18 +- 9 files changed, 412 insertions(+), 200 deletions(-) diff --git a/contrib/codegen/src/database.rs b/contrib/codegen/src/database.rs index 1b0124d3..cb906f94 100644 --- a/contrib/codegen/src/database.rs +++ b/contrib/codegen/src/database.rs @@ -64,23 +64,16 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result ::rocket_contrib::databases); - let Poolable = quote_spanned!(span => #databases::Poolable); - let r2d2 = quote_spanned!(span => #databases::r2d2); - let spawn_blocking = quote_spanned!(span => #databases::spawn_blocking); let request = quote!(::rocket::request); let generated_types = quote_spanned! { span => /// The request guard type. - #vis struct #guard_type(pub #r2d2::PooledConnection<<#conn_type as #Poolable>::Manager>); - - /// The pool type. - #vis struct #pool_type(#r2d2::Pool<<#conn_type as #Poolable>::Manager>); + #vis struct #guard_type(#databases::Connection); }; Ok(quote! { @@ -90,53 +83,31 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result impl ::rocket::fairing::Fairing { - use #databases::Poolable; - - ::rocket::fairing::AdHoc::on_attach(#fairing_name, |mut rocket| async { - let pool = #databases::database_config(#name, rocket.config().await) - .map(<#conn_type>::pool); - - match pool { - Ok(Ok(p)) => Ok(rocket.manage(#pool_type(p))), - Err(config_error) => { - ::rocket::logger::error( - &format!("Database configuration failure: '{}'", #name)); - ::rocket::logger::error_(&format!("{}", config_error)); - Err(rocket) - }, - Ok(Err(pool_error)) => { - ::rocket::logger::error( - &format!("Failed to initialize pool for '{}'", #name)); - ::rocket::logger::error_(&format!("{:?}", pool_error)); - Err(rocket) - }, - } - }) + <#databases::ConnectionPool>::fairing(#fairing_name, #name) } /// Retrieves a connection of type `Self` from the `rocket` /// instance. Returns `Some` as long as `Self::fairing()` has been - /// attached and there is at least one connection in the pool. - pub fn get_one(cargo: &::rocket::Cargo) -> Option { - cargo.state::<#pool_type>() - .and_then(|pool| pool.0.get().ok()) - .map(#guard_type) + /// attached. + pub async fn get_one(cargo: &::rocket::Cargo) -> Option { + <#databases::ConnectionPool>::get_one(cargo).await.map(Self) } - } - impl ::std::ops::Deref for #guard_type { - type Target = #conn_type; - - #[inline(always)] - fn deref(&self) -> &Self::Target { - &self.0 + /// Runs the provided closure on a thread from a threadpool. The + /// closure will be passed an `&mut r2d2::PooledConnection`. + /// `.await`ing the return value of this function yields the value + /// returned by the closure. + pub async fn run(self, f: F) -> R + where + F: FnOnce(&mut #conn_type) -> R + Send + 'static, + R: Send + 'static, + { + self.0.run(f).await } - } - impl ::std::ops::DerefMut for #guard_type { - #[inline(always)] - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 + /// Asynchronously acquires another connection from the connection pool. + pub async fn clone(&mut self) -> ::std::result::Result { + self.0.clone().await.map(Self) } } @@ -145,20 +116,8 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result) -> #request::Outcome { - use ::rocket::http::Status; - - let guard = request.guard::<::rocket::State<'_, #pool_type>>(); - let pool = ::rocket::try_outcome!(guard.await).0.clone(); - - #spawn_blocking(move || { - match pool.get() { - Ok(conn) => #request::Outcome::Success(#guard_type(conn)), - Err(_) => #request::Outcome::Failure((Status::ServiceUnavailable, ())), - } - }).await.expect("failed to spawn a blocking task to get a pooled connection") + <#databases::Connection>::from_request(request).await.map(Self) } } - - // TODO.async: What about spawn_blocking on drop? }.into()) } diff --git a/contrib/codegen/tests/ui-fail-nightly/database-types.stderr b/contrib/codegen/tests/ui-fail-nightly/database-types.stderr index 61b7c741..ce769c9f 100644 --- a/contrib/codegen/tests/ui-fail-nightly/database-types.stderr +++ b/contrib/codegen/tests/ui-fail-nightly/database-types.stderr @@ -1,11 +1,21 @@ -error[E0277]: the trait bound `Unknown: Poolable` is not satisfied - --> $DIR/database-types.rs:7:10 - | -7 | struct A(Unknown); - | ^^^^^^^ the trait `Poolable` is not implemented for `Unknown` +error[E0277]: the trait bound `Unknown: rocket_contrib::databases::Poolable` is not satisfied + --> $DIR/database-types.rs:7:10 + | +7 | struct A(Unknown); + | ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown` + | + ::: $WORKSPACE/contrib/lib/src/databases.rs:832:29 + | +832 | pub struct Connection { + | -------- required by this bound in `rocket_contrib::databases::Connection` -error[E0277]: the trait bound `Vec: Poolable` is not satisfied - --> $DIR/database-types.rs:10:10 - | -10 | struct B(Vec); - | ^^^^^^^^ the trait `Poolable` is not implemented for `Vec` +error[E0277]: the trait bound `std::vec::Vec: rocket_contrib::databases::Poolable` is not satisfied + --> $DIR/database-types.rs:10:10 + | +10 | struct B(Vec); + | ^^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec` + | + ::: $WORKSPACE/contrib/lib/src/databases.rs:832:29 + | +832 | pub struct Connection { + | -------- required by this bound in `rocket_contrib::databases::Connection` diff --git a/contrib/codegen/tests/ui-fail-stable/database-types.stderr b/contrib/codegen/tests/ui-fail-stable/database-types.stderr index 7b85009e..dcc96154 100644 --- a/contrib/codegen/tests/ui-fail-stable/database-types.stderr +++ b/contrib/codegen/tests/ui-fail-stable/database-types.stderr @@ -1,11 +1,21 @@ error[E0277]: the trait bound `Unknown: rocket_contrib::databases::Poolable` is not satisfied - --> $DIR/database-types.rs:7:10 - | -7 | struct A(Unknown); - | ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown` + --> $DIR/database-types.rs:7:10 + | +7 | struct A(Unknown); + | ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown` + | + ::: $WORKSPACE/contrib/lib/src/databases.rs + | + | pub struct Connection { + | -------- required by this bound in `rocket_contrib::databases::Connection` error[E0277]: the trait bound `std::vec::Vec: rocket_contrib::databases::Poolable` is not satisfied - --> $DIR/database-types.rs:10:10 - | -10 | struct B(Vec); - | ^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec` + --> $DIR/database-types.rs:10:10 + | +10 | struct B(Vec); + | ^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec` + | + ::: $WORKSPACE/contrib/lib/src/databases.rs + | + | pub struct Connection { + | -------- required by this bound in `rocket_contrib::databases::Connection` diff --git a/contrib/lib/src/databases.rs b/contrib/lib/src/databases.rs index d89af585..039c4da6 100644 --- a/contrib/lib/src/databases.rs +++ b/contrib/lib/src/databases.rs @@ -84,9 +84,9 @@ //! # type Result = std::result::Result; //! # //! #[get("/logs/")] -//! fn get_logs(conn: LogsDbConn, id: usize) -> Result { +//! async fn get_logs(conn: LogsDbConn, id: usize) -> Result { //! # /* -//! Logs::by_id(&*conn, id) +//! conn.run(|c| Logs::by_id(c, id)).await //! # */ //! # Ok(()) //! } @@ -224,9 +224,7 @@ //! The macro generates a [`FromRequest`] implementation for the decorated type, //! allowing the type to be used as a request guard. This implementation //! retrieves a connection from the database pool or fails with a -//! `Status::ServiceUnavailable` if no connections are available. The macro also -//! generates an implementation of the [`Deref`](std::ops::Deref) trait with -//! the internal `Poolable` type as the target. +//! `Status::ServiceUnavailable` if connecting to the database times out. //! //! The macro will also generate two inherent methods on the decorated type: //! @@ -235,11 +233,10 @@ //! Returns a fairing that initializes the associated database connection //! pool. //! -//! * `fn get_one(&Cargo) -> Option` +//! * `async fn get_one(&Cargo) -> Option` //! -//! Retrieves a connection from the configured pool. Returns `Some` as long -//! as `Self::fairing()` has been attached and there is at least one -//! connection in the pool. +//! Retrieves a connection wrapper from the configured pool. Returns `Some` +//! as long as `Self::fairing()` has been attached. //! //! The fairing returned from the generated `fairing()` method _must_ be //! attached for the request guard implementation to succeed. Putting the pieces @@ -280,8 +277,8 @@ //! //! ## Handlers //! -//! Finally, simply use your type as a request guard in a handler to retrieve a -//! connection to a given database: +//! Finally, use your type as a request guard in a handler to retrieve a +//! connection wrapper for the database: //! //! ```rust //! # #[macro_use] extern crate rocket; @@ -300,8 +297,7 @@ //! # } //! ``` //! -//! The generated `Deref` implementation allows easy access to the inner -//! connection type: +//! A connection can be retrieved and used with the `run()` method: //! //! ```rust //! # #[macro_use] extern crate rocket; @@ -320,12 +316,43 @@ //! } //! //! #[get("/")] -//! fn my_handler(conn: MyDatabase) -> Data { -//! load_from_db(&*conn) +//! async fn my_handler(mut conn: MyDatabase) -> Data { +//! conn.run(|c| load_from_db(c)).await //! } //! # } //! ``` //! +//! `run()` takes the connection by value. To make multiple calls to run, +//! obtain a second connection first with `clone()`: +//! +//! ```rust +//! # #[macro_use] extern crate rocket; +//! # #[macro_use] extern crate rocket_contrib; +//! # +//! # #[cfg(feature = "diesel_sqlite_pool")] +//! # mod test { +//! # use rocket_contrib::databases::diesel; +//! # type Data = (); +//! #[database("my_db")] +//! struct MyDatabase(diesel::SqliteConnection); +//! +//! fn load_from_db(conn: &diesel::SqliteConnection) -> Data { +//! // Do something with connection, return some data. +//! # () +//! } +//! +//! #[get("/")] +//! async fn my_handler(mut conn: MyDatabase) -> Data { +//! let cloned = conn.clone().await.expect(""); +//! cloned.run(|c| load_from_db(c)).await; +//! +//! // Do something else +//! conn.run(|c| load_from_db(c)).await; +//! } +//! # } +//! ``` +//! +//! //! # Database Support //! //! Built-in support is provided for many popular databases and drivers. Support @@ -380,18 +407,21 @@ pub extern crate r2d2; -#[doc(hidden)] -pub use tokio::task::spawn_blocking; - #[cfg(any(feature = "diesel_sqlite_pool", feature = "diesel_postgres_pool", feature = "diesel_mysql_pool"))] pub extern crate diesel; use std::fmt::{self, Display, Formatter}; -use std::marker::{Send, Sized}; +use std::marker::PhantomData; +use std::sync::Arc; use rocket::config::{self, Value}; +use rocket::fairing::{AdHoc, Fairing}; +use rocket::http::Status; +use rocket::request::Outcome; + +use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore}; use self::r2d2::ManageConnection; @@ -688,7 +718,7 @@ pub trait Poolable: Send + Sized + 'static { type Manager: ManageConnection; /// The associated error type in the event that constructing the connection /// manager and/or the connection pool fails. - type Error; + type Error: std::fmt::Debug; /// Creates an `r2d2` connection pool for `Manager::Connection`, returning /// the pool on success. @@ -780,6 +810,184 @@ impl Poolable for memcache::Client { } } +/// Unstable internal details of generated code for the #[database] attribute. +/// +/// This type is implemented here instead of in generated code to ensure all +/// types are properly checked. +#[doc(hidden)] +pub struct ConnectionPool { + pool: r2d2::Pool, + semaphore: Arc, + _marker: PhantomData K>, +} + +/// Unstable internal details of generated code for the #[database] attribute. +/// +/// This type is implemented here instead of in generated code to ensure all +/// types are properly checked. +#[doc(hidden)] +pub struct Connection { + pool: ConnectionPool, + connection: Option>, + _permit: Option, + _marker: PhantomData K>, +} + +// A wrapper around spawn_blocking that propagates panics to the calling code +async fn run_blocking(job: F) -> R +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + match tokio::task::spawn_blocking(job).await { + Ok(ret) => ret, + Err(e) => match e.try_into_panic() { + Ok(panic) => std::panic::resume_unwind(panic), + Err(_) => unreachable!("spawn_blocking tasks are never canceled"), + } + } +} + +impl ConnectionPool { + #[inline] + pub fn fairing(fairing_name: &'static str, config_name: &'static str) -> impl Fairing { + AdHoc::on_attach(fairing_name, move |mut rocket| async move { + let config = database_config(config_name, rocket.config().await); + let pool = config.map(|c| (c.pool_size, C::pool(c))); + + match pool { + Ok((size, Ok(pool))) => { + let managed = ConnectionPool:: { + pool, + semaphore: Arc::new(Semaphore::new(size as usize)), + _marker: PhantomData, + }; + Ok(rocket.manage(managed)) + }, + Err(config_error) => { + rocket::logger::error( + &format!("Database configuration failure: '{}'", config_name)); + rocket::logger::error_(&config_error.to_string()); + Err(rocket) + }, + Ok((_, Err(pool_error))) => { + rocket::logger::error( + &format!("Failed to initialize pool for '{}'", config_name)); + rocket::logger::error_(&format!("{:?}", pool_error)); + Err(rocket) + }, + } + }) + } + + async fn get(&self) -> Result, ()> { + // TODO: timeout duration + let permit = match tokio::time::timeout( + self.pool.connection_timeout(), + self.semaphore.clone().acquire_owned() + ).await { + Ok(p) => p, + Err(_) => { + error_!("Failed to get a database connection within the timeout."); + return Err(()); + } + }; + + let pool = self.pool.clone(); + + // TODO: timeout duration + match run_blocking(move || pool.get_timeout(std::time::Duration::from_secs(0))).await { + Ok(c) => { + Ok(Connection { + pool: self.clone(), + connection: Some(c), + _permit: Some(permit), + _marker: PhantomData, + }) + } + Err(e) => { + error_!("Failed to get a database connection: {}", e); + Err(()) + } + } + } + + #[inline] + pub async fn get_one(cargo: &rocket::Cargo) -> Option> { + match cargo.state::() { + Some(pool) => pool.get().await.ok(), + None => { + error_!("Database fairing was not attached for {}", std::any::type_name::()); + None + } + } + } +} + +impl Clone for ConnectionPool { + fn clone(&self) -> Self { + Self { + pool: self.pool.clone(), + semaphore: self.semaphore.clone(), + _marker: PhantomData, + } + } +} + +impl Connection { + #[inline] + pub async fn run(self, f: F) -> R + where F: FnOnce(&mut C) -> R + Send + 'static, + R: Send + 'static, + { + run_blocking(move || { + // 'self' contains a semaphore permit that should be held throughout + // this entire spawned task. Explicitly move it, so that a + // refactoring won't accidentally release the permit too early. + let mut this: Self = self; + + let mut conn = this.connection.take() + .expect("internal invariant broken: self.connection is Some"); + + f(&mut conn) + }).await + } + + #[inline] + pub async fn clone(&mut self) -> Result { + self.pool.get().await + } +} + +impl Drop for Connection { + fn drop(&mut self) { + if let Some(conn) = self.connection.take() { + tokio::task::spawn_blocking(|| drop(conn)); + } + } +} + +#[rocket::async_trait] +impl<'a, 'r, K: 'static, C: Poolable> rocket::request::FromRequest<'a, 'r> for Connection { + type Error = (); + + #[inline] + async fn from_request(request: &'a rocket::request::Request<'r>) -> Outcome { + match request.managed_state::>() { + Some(inner) => { + match inner.get().await { + Ok(c) => Outcome::Success(c), + Err(()) => Outcome::Failure((Status::ServiceUnavailable, ())), + } + } + None => { + error_!("Database fairing was not attached for {}", std::any::type_name::()); + Outcome::Failure((Status::InternalServerError, ())) + } + } + } +} + #[cfg(test)] mod tests { use std::collections::BTreeMap; diff --git a/contrib/lib/tests/databases.rs b/contrib/lib/tests/databases.rs index 47ead732..008e05c4 100644 --- a/contrib/lib/tests/databases.rs +++ b/contrib/lib/tests/databases.rs @@ -21,40 +21,31 @@ mod rusqlite_integration_test { #[database("test_db")] struct SqliteDb(pub rusqlite::Connection); + // Test to ensure that multiple databases of the same type can be used + #[database("test_db_2")] + struct SqliteDb2(pub rusqlite::Connection); + #[rocket::async_test] - async fn deref_mut_impl_present() { + async fn test_db() { let mut test_db: Map = Map::new(); let mut test_db_opts: Map = Map::new(); test_db_opts.insert("url".into(), Value::String(":memory:".into())); - test_db.insert("test_db".into(), Value::Table(test_db_opts)); + test_db.insert("test_db".into(), Value::Table(test_db_opts.clone())); + test_db.insert("test_db_2".into(), Value::Table(test_db_opts)); let config = Config::build(Environment::Development) .extra("databases", Value::Table(test_db)) .finalize() .unwrap(); - let mut rocket = rocket::custom(config).attach(SqliteDb::fairing()); - let mut conn = SqliteDb::get_one(rocket.inspect().await).expect("unable to get connection"); + let mut rocket = rocket::custom(config).attach(SqliteDb::fairing()).attach(SqliteDb2::fairing()); + let conn = SqliteDb::get_one(rocket.inspect().await).await.expect("unable to get connection"); - // Rusqlite's `transaction()` method takes `&mut self`; this tests the - // presence of a `DerefMut` trait on the generated connection type. - let tx = conn.transaction().unwrap(); - let _: i32 = tx.query_row("SELECT 1", &[] as &[&dyn ToSql], |row| row.get(0)).expect("get row"); - tx.commit().expect("committed transaction"); - } - - #[rocket::async_test] - async fn deref_impl_present() { - let mut test_db: Map = Map::new(); - let mut test_db_opts: Map = Map::new(); - test_db_opts.insert("url".into(), Value::String(":memory:".into())); - test_db.insert("test_db".into(), Value::Table(test_db_opts)); - let config = Config::build(Environment::Development) - .extra("databases", Value::Table(test_db)) - .finalize() - .unwrap(); - - let mut rocket = rocket::custom(config).attach(SqliteDb::fairing()); - let conn = SqliteDb::get_one(rocket.inspect().await).expect("unable to get connection"); - let _: i32 = conn.query_row("SELECT 1", &[] as &[&dyn ToSql], |row| row.get(0)).expect("get row"); + // Rusqlite's `transaction()` method takes `&mut self`; this tests that + // the &mut method can be called inside the closure passed to `run()`. + conn.run(|conn| { + let tx = conn.transaction().unwrap(); + let _: i32 = tx.query_row("SELECT 1", &[] as &[&dyn ToSql], |row| row.get(0)).expect("get row"); + tx.commit().expect("committed transaction"); + }).await; } } diff --git a/examples/todo/src/main.rs b/examples/todo/src/main.rs index 67563f96..a37f51e6 100644 --- a/examples/todo/src/main.rs +++ b/examples/todo/src/main.rs @@ -25,23 +25,26 @@ embed_migrations!(); pub struct DbConn(SqliteConnection); #[derive(Debug, serde::Serialize)] -struct Context<'a> { - msg: Option<(&'a str, &'a str)>, +struct Context { + msg: Option<(String, String)>, tasks: Vec } -impl<'a> Context<'a> { - pub fn err(conn: &DbConn, msg: &'a str) -> Context<'a> { - Context { msg: Some(("error", msg)), tasks: Task::all(conn).unwrap_or_default() } +impl Context { + pub async fn err(conn: DbConn, msg: String) -> Context { + Context { + msg: Some(("error".to_string(), msg)), + tasks: Task::all(conn).await.unwrap_or_default() + } } - pub fn raw(conn: &DbConn, msg: Option<(&'a str, &'a str)>) -> Context<'a> { - match Task::all(conn) { + pub async fn raw(conn: DbConn, msg: Option<(String, String)>) -> Context { + match Task::all(conn).await { Ok(tasks) => Context { msg, tasks }, Err(e) => { error_!("DB Task::all() error: {}", e); Context { - msg: Some(("error", "Couldn't access the task database.")), + msg: Some(("error".to_string(), "Couldn't access the task database.".to_string())), tasks: vec![] } } @@ -50,11 +53,11 @@ impl<'a> Context<'a> { } #[post("/", data = "")] -fn new(todo_form: Form, conn: DbConn) -> Flash { +async fn new(todo_form: Form, conn: DbConn) -> Flash { let todo = todo_form.into_inner(); if todo.description.is_empty() { Flash::error(Redirect::to("/"), "Description cannot be empty.") - } else if let Err(e) = Task::insert(todo, &conn) { + } else if let Err(e) = Task::insert(todo, conn).await { error_!("DB insertion error: {}", e); Flash::error(Redirect::to("/"), "Todo could not be inserted due an internal error.") } else { @@ -63,42 +66,48 @@ fn new(todo_form: Form, conn: DbConn) -> Flash { } #[put("/")] -fn toggle(id: i32, conn: DbConn) -> Result { - Task::toggle_with_id(id, &conn) - .map(|_| Redirect::to("/")) - .map_err(|e| { +async fn toggle(id: i32, mut conn: DbConn) -> Result { + // TODO + let conn2 = conn.clone().await.unwrap(); + match Task::toggle_with_id(id, conn).await { + Ok(_) => Ok(Redirect::to("/")), + Err(e) => { error_!("DB toggle({}) error: {}", id, e); - Template::render("index", Context::err(&conn, "Failed to toggle task.")) - }) + Err(Template::render("index", Context::err(conn2, "Failed to toggle task.".to_string()).await)) + } + } } #[delete("/")] -fn delete(id: i32, conn: DbConn) -> Result, Template> { - Task::delete_with_id(id, &conn) - .map(|_| Flash::success(Redirect::to("/"), "Todo was deleted.")) - .map_err(|e| { +async fn delete(id: i32, mut conn: DbConn) -> Result, Template> { + // TODO + let conn2 = conn.clone().await.unwrap(); + match Task::delete_with_id(id, conn).await { + Ok(_) => Ok(Flash::success(Redirect::to("/"), "Todo was deleted.")), + Err(e) => { error_!("DB deletion({}) error: {}", id, e); - Template::render("index", Context::err(&conn, "Failed to delete task.")) - }) + Err(Template::render("index", Context::err(conn2, "Failed to delete task.".to_string()).await)) + } + } } #[get("/")] -fn index(msg: Option>, conn: DbConn) -> Template { - Template::render("index", match msg { - Some(ref msg) => Context::raw(&conn, Some((msg.name(), msg.msg()))), - None => Context::raw(&conn, None), - }) +async fn index(msg: Option>, conn: DbConn) -> Template { + let msg = msg.map(|m| (m.name().to_string(), m.msg().to_string())); + Template::render("index", Context::raw(conn, msg).await) } async fn run_db_migrations(mut rocket: Rocket) -> Result { - let conn = DbConn::get_one(rocket.inspect().await).expect("database connection"); - match embedded_migrations::run(&*conn) { - Ok(()) => Ok(rocket), - Err(e) => { - error!("Failed to run database migrations: {:?}", e); - Err(rocket) + let conn = DbConn::get_one(rocket.inspect().await).await.expect("database connection"); + conn.run(|c| { + match embedded_migrations::run(c) { + Ok(()) => Ok(rocket), + Err(e) => { + error!("Failed to run database migrations: {:?}", e); + Err(rocket) + } } - } + }).await } #[launch] diff --git a/examples/todo/src/task.rs b/examples/todo/src/task.rs index 89167446..0e9ad8a8 100644 --- a/examples/todo/src/task.rs +++ b/examples/todo/src/task.rs @@ -13,6 +13,8 @@ mod schema { use self::schema::tasks; use self::schema::tasks::dsl::{tasks as all_tasks, completed as task_completed}; +use crate::DbConn; + #[table_name="tasks"] #[derive(serde::Serialize, Queryable, Insertable, Debug, Clone)] pub struct Task { @@ -27,32 +29,38 @@ pub struct Todo { } impl Task { - pub fn all(conn: &SqliteConnection) -> QueryResult> { - all_tasks.order(tasks::id.desc()).load::(conn) + pub async fn all(conn: DbConn) -> QueryResult> { + conn.run(|c| { + all_tasks.order(tasks::id.desc()).load::(c) + }).await } /// Returns the number of affected rows: 1. - pub fn insert(todo: Todo, conn: &SqliteConnection) -> QueryResult { - let t = Task { id: None, description: todo.description, completed: false }; - diesel::insert_into(tasks::table).values(&t).execute(conn) + pub async fn insert(todo: Todo, conn: DbConn) -> QueryResult { + conn.run(|c| { + let t = Task { id: None, description: todo.description, completed: false }; + diesel::insert_into(tasks::table).values(&t).execute(c) + }).await } /// Returns the number of affected rows: 1. - pub fn toggle_with_id(id: i32, conn: &SqliteConnection) -> QueryResult { - let task = all_tasks.find(id).get_result::(conn)?; - let new_status = !task.completed; - let updated_task = diesel::update(all_tasks.find(id)); - updated_task.set(task_completed.eq(new_status)).execute(conn) + pub async fn toggle_with_id(id: i32, conn: DbConn) -> QueryResult { + conn.run(move |c| { + let task = all_tasks.find(id).get_result::(c)?; + let new_status = !task.completed; + let updated_task = diesel::update(all_tasks.find(id)); + updated_task.set(task_completed.eq(new_status)).execute(c) + }).await } /// Returns the number of affected rows: 1. - pub fn delete_with_id(id: i32, conn: &SqliteConnection) -> QueryResult { - diesel::delete(all_tasks.find(id)).execute(conn) + pub async fn delete_with_id(id: i32, conn: DbConn) -> QueryResult { + conn.run(move |c| diesel::delete(all_tasks.find(id)).execute(c)).await } /// Returns the number of affected rows. #[cfg(test)] - pub fn delete_all(conn: &SqliteConnection) -> QueryResult { - diesel::delete(all_tasks).execute(conn) + pub async fn delete_all(conn: DbConn) -> QueryResult { + conn.run(|c| diesel::delete(all_tasks).execute(c)).await } } diff --git a/examples/todo/src/tests.rs b/examples/todo/src/tests.rs index f4f72f3a..73930b80 100644 --- a/examples/todo/src/tests.rs +++ b/examples/todo/src/tests.rs @@ -3,7 +3,7 @@ use super::task::Task; use parking_lot::Mutex; use rand::{Rng, thread_rng, distributions::Alphanumeric}; -use rocket::local::blocking::Client; +use rocket::local::asynchronous::Client; use rocket::http::{Status, ContentType}; // We use a lock to synchronize between tests so DB operations don't collide. @@ -15,13 +15,16 @@ macro_rules! run_test { (|$client:ident, $conn:ident| $block:expr) => ({ let _lock = DB_LOCK.lock(); - let rocket = super::rocket(); - let $client = Client::new(rocket).expect("Rocket client"); - let db = super::DbConn::get_one($client.cargo()); - let $conn = db.expect("failed to get database connection for testing"); - Task::delete_all(&$conn).expect("failed to delete all tasks for testing"); + rocket::async_test(async move { + let rocket = super::rocket(); + let $client = Client::new(rocket).await.expect("Rocket client"); + let db = super::DbConn::get_one($client.cargo()).await; + let mut $conn = db.expect("failed to get database connection for testing"); + let delete_conn = $conn.clone().await.expect("failed to get a second database connection for testing"); + Task::delete_all(delete_conn).await.expect("failed to delete all tasks for testing"); - $block + $block + }) }) } @@ -29,16 +32,17 @@ macro_rules! run_test { fn test_insertion_deletion() { run_test!(|client, conn| { // Get the tasks before making changes. - let init_tasks = Task::all(&conn).unwrap(); + let init_tasks = Task::all(conn.clone().await.unwrap()).await.unwrap(); // Issue a request to insert a new task. client.post("/todo") .header(ContentType::Form) .body("description=My+first+task") - .dispatch(); + .dispatch() + .await; // Ensure we have one more task in the database. - let new_tasks = Task::all(&conn).unwrap(); + let new_tasks = Task::all(conn.clone().await.unwrap()).await.unwrap(); assert_eq!(new_tasks.len(), init_tasks.len() + 1); // Ensure the task is what we expect. @@ -47,10 +51,10 @@ fn test_insertion_deletion() { // Issue a request to delete the task. let id = new_tasks[0].id.unwrap(); - client.delete(format!("/todo/{}", id)).dispatch(); + client.delete(format!("/todo/{}", id)).dispatch().await; // Ensure it's gone. - let final_tasks = Task::all(&conn).unwrap(); + let final_tasks = Task::all(conn).await.unwrap(); assert_eq!(final_tasks.len(), init_tasks.len()); if final_tasks.len() > 0 { assert_ne!(final_tasks[0].description, "My first task"); @@ -65,18 +69,19 @@ fn test_toggle() { client.post("/todo") .header(ContentType::Form) .body("description=test_for_completion") - .dispatch(); + .dispatch() + .await; - let task = Task::all(&conn).unwrap()[0].clone(); + let task = Task::all(conn.clone().await.unwrap()).await.unwrap()[0].clone(); assert_eq!(task.completed, false); // Issue a request to toggle the task; ensure it is completed. - client.put(format!("/todo/{}", task.id.unwrap())).dispatch(); - assert_eq!(Task::all(&conn).unwrap()[0].completed, true); + client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await; + assert_eq!(Task::all(conn.clone().await.unwrap()).await.unwrap()[0].completed, true); // Issue a request to toggle the task; ensure it's not completed again. - client.put(format!("/todo/{}", task.id.unwrap())).dispatch(); - assert_eq!(Task::all(&conn).unwrap()[0].completed, false); + client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await; + assert_eq!(Task::all(conn).await.unwrap()[0].completed, false); }) } @@ -86,7 +91,7 @@ fn test_many_insertions() { run_test!(|client, conn| { // Get the number of tasks initially. - let init_num = Task::all(&conn).unwrap().len(); + let init_num = Task::all(conn.clone().await.unwrap()).await.unwrap().len(); let mut descs = Vec::new(); for i in 0..ITER { @@ -95,13 +100,14 @@ fn test_many_insertions() { client.post("/todo") .header(ContentType::Form) .body(format!("description={}", desc)) - .dispatch(); + .dispatch() + .await; // Record the description we choose for this iteration. descs.insert(0, desc); // Ensure the task was inserted properly and all other tasks remain. - let tasks = Task::all(&conn).unwrap(); + let tasks = Task::all(conn.clone().await.unwrap()).await.unwrap(); assert_eq!(tasks.len(), init_num + i + 1); for j in 0..i { @@ -117,7 +123,8 @@ fn test_bad_form_submissions() { // Submit an empty form. We should get a 422 but no flash error. let res = client.post("/todo") .header(ContentType::Form) - .dispatch(); + .dispatch() + .await; let mut cookies = res.headers().get("Set-Cookie"); assert_eq!(res.status(), Status::UnprocessableEntity); @@ -128,7 +135,8 @@ fn test_bad_form_submissions() { let res = client.post("/todo") .header(ContentType::Form) .body("description=") - .dispatch(); + .dispatch() + .await; let mut cookies = res.headers().get("Set-Cookie"); assert!(cookies.any(|value| value.contains("error"))); @@ -137,7 +145,8 @@ fn test_bad_form_submissions() { let res = client.post("/todo") .header(ContentType::Form) .body("evil=smile") - .dispatch(); + .dispatch() + .await; let mut cookies = res.headers().get("Set-Cookie"); assert_eq!(res.status(), Status::UnprocessableEntity); diff --git a/site/guide/6-state.md b/site/guide/6-state.md index d941bb1e..60495134 100644 --- a/site/guide/6-state.md +++ b/site/guide/6-state.md @@ -205,7 +205,6 @@ request-local state to implement request timing. [`FromRequest` request-local state]: @api/rocket/request/trait.FromRequest.html#request-local-state [`Fairing`]: @api/rocket/fairing/trait.Fairing.html#request-local-state - ## Databases Rocket includes built-in, ORM-agnostic support for databases. In particular, @@ -222,7 +221,7 @@ three simple steps: 1. Configure the databases in `Rocket.toml`. 2. Associate a request guard type and fairing with each database. - 3. Use the request guard to retrieve a connection in a handler. + 3. Use the request guard to retrieve and use a connection in a handler. Presently, Rocket provides built-in support for the following databases: @@ -301,7 +300,7 @@ fn rocket() -> rocket::Rocket { ``` That's it! Whenever a connection to the database is needed, use your type as a -request guard: +request guard. The database can be accessed by calling the `run` method: ```rust # #[macro_use] extern crate rocket; @@ -315,9 +314,9 @@ request guard: # type Logs = (); #[get("/logs/")] -fn get_logs(conn: LogsDbConn, id: usize) -> Logs { +async fn get_logs(conn: LogsDbConn, id: usize) -> Logs { # /* - logs::filter(id.eq(log_id)).load(&*conn) + conn.run(|c| logs::filter(id.eq(log_id)).load(c)).await # */ } ``` @@ -329,6 +328,15 @@ fn get_logs(conn: LogsDbConn, id: usize) -> Logs { syntax. Rocket does not provide an ORM. It is up to you to decide how to model your application's data. +! note + + The database engines supported by `#[database]` are *synchronous*. Normally, + using such a database would block the thread of execution. To prevent this, + the `run()` function automatically uses a thread pool so that database access + does not interfere with other in-flight requests. See [Cooperative + Multitasking](../overview/#cooperative-multitasking) for more information on + why this is necessary. + If your application uses features of a database engine that are not available by default, for example support for `chrono` or `uuid`, you may enable those features by adding them in `Cargo.toml` like so: