Use 'spawn_blocking' in '#[database]'.

The connection guard type generated by `#[database]` no longer
implements `Deref` and `DerefMut`. Instead, it provides an `async fn
run()` that gives access to the underlying connection on a closure run
through `spawn_blocking()`.

Additionally moves most of the implementation of `#[database]` out
of generated code and into library code for better type-checking.
This commit is contained in:
Jeb Rosen 2020-07-11 11:17:43 -07:00 committed by Sergio Benitez
parent f7eacb6a65
commit bc8c5b9ee2
9 changed files with 412 additions and 200 deletions

View File

@ -64,23 +64,16 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
let name = &invocation.db_name; let name = &invocation.db_name;
let guard_type = &invocation.type_name; let guard_type = &invocation.type_name;
let vis = &invocation.visibility; let vis = &invocation.visibility;
let pool_type = Ident::new(&format!("{}Pool", guard_type), guard_type.span());
let fairing_name = format!("'{}' Database Pool", name); let fairing_name = format!("'{}' Database Pool", name);
let span = conn_type.span().into(); let span = conn_type.span().into();
// A few useful paths. // A few useful paths.
let databases = quote_spanned!(span => ::rocket_contrib::databases); let databases = quote_spanned!(span => ::rocket_contrib::databases);
let Poolable = quote_spanned!(span => #databases::Poolable);
let r2d2 = quote_spanned!(span => #databases::r2d2);
let spawn_blocking = quote_spanned!(span => #databases::spawn_blocking);
let request = quote!(::rocket::request); let request = quote!(::rocket::request);
let generated_types = quote_spanned! { span => let generated_types = quote_spanned! { span =>
/// The request guard type. /// The request guard type.
#vis struct #guard_type(pub #r2d2::PooledConnection<<#conn_type as #Poolable>::Manager>); #vis struct #guard_type(#databases::Connection<Self, #conn_type>);
/// The pool type.
#vis struct #pool_type(#r2d2::Pool<<#conn_type as #Poolable>::Manager>);
}; };
Ok(quote! { Ok(quote! {
@ -90,53 +83,31 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
/// Returns a fairing that initializes the associated database /// Returns a fairing that initializes the associated database
/// connection pool. /// connection pool.
pub fn fairing() -> impl ::rocket::fairing::Fairing { pub fn fairing() -> impl ::rocket::fairing::Fairing {
use #databases::Poolable; <#databases::ConnectionPool<Self, #conn_type>>::fairing(#fairing_name, #name)
::rocket::fairing::AdHoc::on_attach(#fairing_name, |mut rocket| async {
let pool = #databases::database_config(#name, rocket.config().await)
.map(<#conn_type>::pool);
match pool {
Ok(Ok(p)) => Ok(rocket.manage(#pool_type(p))),
Err(config_error) => {
::rocket::logger::error(
&format!("Database configuration failure: '{}'", #name));
::rocket::logger::error_(&format!("{}", config_error));
Err(rocket)
},
Ok(Err(pool_error)) => {
::rocket::logger::error(
&format!("Failed to initialize pool for '{}'", #name));
::rocket::logger::error_(&format!("{:?}", pool_error));
Err(rocket)
},
}
})
} }
/// Retrieves a connection of type `Self` from the `rocket` /// Retrieves a connection of type `Self` from the `rocket`
/// instance. Returns `Some` as long as `Self::fairing()` has been /// instance. Returns `Some` as long as `Self::fairing()` has been
/// attached and there is at least one connection in the pool. /// attached.
pub fn get_one(cargo: &::rocket::Cargo) -> Option<Self> { pub async fn get_one(cargo: &::rocket::Cargo) -> Option<Self> {
cargo.state::<#pool_type>() <#databases::ConnectionPool<Self, #conn_type>>::get_one(cargo).await.map(Self)
.and_then(|pool| pool.0.get().ok())
.map(#guard_type)
}
} }
impl ::std::ops::Deref for #guard_type { /// Runs the provided closure on a thread from a threadpool. The
type Target = #conn_type; /// closure will be passed an `&mut r2d2::PooledConnection`.
/// `.await`ing the return value of this function yields the value
#[inline(always)] /// returned by the closure.
fn deref(&self) -> &Self::Target { pub async fn run<F, R>(self, f: F) -> R
&self.0 where
} F: FnOnce(&mut #conn_type) -> R + Send + 'static,
R: Send + 'static,
{
self.0.run(f).await
} }
impl ::std::ops::DerefMut for #guard_type { /// Asynchronously acquires another connection from the connection pool.
#[inline(always)] pub async fn clone(&mut self) -> ::std::result::Result<Self, ()> {
fn deref_mut(&mut self) -> &mut Self::Target { self.0.clone().await.map(Self)
&mut self.0
} }
} }
@ -145,20 +116,8 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
type Error = (); type Error = ();
async fn from_request(request: &'a #request::Request<'r>) -> #request::Outcome<Self, ()> { async fn from_request(request: &'a #request::Request<'r>) -> #request::Outcome<Self, ()> {
use ::rocket::http::Status; <#databases::Connection<Self, #conn_type>>::from_request(request).await.map(Self)
let guard = request.guard::<::rocket::State<'_, #pool_type>>();
let pool = ::rocket::try_outcome!(guard.await).0.clone();
#spawn_blocking(move || {
match pool.get() {
Ok(conn) => #request::Outcome::Success(#guard_type(conn)),
Err(_) => #request::Outcome::Failure((Status::ServiceUnavailable, ())),
}
}).await.expect("failed to spawn a blocking task to get a pooled connection")
} }
} }
// TODO.async: What about spawn_blocking on drop?
}.into()) }.into())
} }

View File

@ -1,11 +1,21 @@
error[E0277]: the trait bound `Unknown: Poolable` is not satisfied error[E0277]: the trait bound `Unknown: rocket_contrib::databases::Poolable` is not satisfied
--> $DIR/database-types.rs:7:10 --> $DIR/database-types.rs:7:10
| |
7 | struct A(Unknown); 7 | struct A(Unknown);
| ^^^^^^^ the trait `Poolable` is not implemented for `Unknown` | ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown`
|
::: $WORKSPACE/contrib/lib/src/databases.rs:832:29
|
832 | pub struct Connection<K, C: Poolable> {
| -------- required by this bound in `rocket_contrib::databases::Connection`
error[E0277]: the trait bound `Vec<i32>: Poolable` is not satisfied error[E0277]: the trait bound `std::vec::Vec<i32>: rocket_contrib::databases::Poolable` is not satisfied
--> $DIR/database-types.rs:10:10 --> $DIR/database-types.rs:10:10
| |
10 | struct B(Vec<i32>); 10 | struct B(Vec<i32>);
| ^^^^^^^^ the trait `Poolable` is not implemented for `Vec<i32>` | ^^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec<i32>`
|
::: $WORKSPACE/contrib/lib/src/databases.rs:832:29
|
832 | pub struct Connection<K, C: Poolable> {
| -------- required by this bound in `rocket_contrib::databases::Connection`

View File

@ -3,9 +3,19 @@ error[E0277]: the trait bound `Unknown: rocket_contrib::databases::Poolable` is
| |
7 | struct A(Unknown); 7 | struct A(Unknown);
| ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown` | ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown`
|
::: $WORKSPACE/contrib/lib/src/databases.rs
|
| pub struct Connection<K, C: Poolable> {
| -------- required by this bound in `rocket_contrib::databases::Connection`
error[E0277]: the trait bound `std::vec::Vec<i32>: rocket_contrib::databases::Poolable` is not satisfied error[E0277]: the trait bound `std::vec::Vec<i32>: rocket_contrib::databases::Poolable` is not satisfied
--> $DIR/database-types.rs:10:10 --> $DIR/database-types.rs:10:10
| |
10 | struct B(Vec<i32>); 10 | struct B(Vec<i32>);
| ^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec<i32>` | ^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec<i32>`
|
::: $WORKSPACE/contrib/lib/src/databases.rs
|
| pub struct Connection<K, C: Poolable> {
| -------- required by this bound in `rocket_contrib::databases::Connection`

View File

@ -84,9 +84,9 @@
//! # type Result<T> = std::result::Result<T, ()>; //! # type Result<T> = std::result::Result<T, ()>;
//! # //! #
//! #[get("/logs/<id>")] //! #[get("/logs/<id>")]
//! fn get_logs(conn: LogsDbConn, id: usize) -> Result<Logs> { //! async fn get_logs(conn: LogsDbConn, id: usize) -> Result<Logs> {
//! # /* //! # /*
//! Logs::by_id(&*conn, id) //! conn.run(|c| Logs::by_id(c, id)).await
//! # */ //! # */
//! # Ok(()) //! # Ok(())
//! } //! }
@ -224,9 +224,7 @@
//! The macro generates a [`FromRequest`] implementation for the decorated type, //! The macro generates a [`FromRequest`] implementation for the decorated type,
//! allowing the type to be used as a request guard. This implementation //! allowing the type to be used as a request guard. This implementation
//! retrieves a connection from the database pool or fails with a //! retrieves a connection from the database pool or fails with a
//! `Status::ServiceUnavailable` if no connections are available. The macro also //! `Status::ServiceUnavailable` if connecting to the database times out.
//! generates an implementation of the [`Deref`](std::ops::Deref) trait with
//! the internal `Poolable` type as the target.
//! //!
//! The macro will also generate two inherent methods on the decorated type: //! The macro will also generate two inherent methods on the decorated type:
//! //!
@ -235,11 +233,10 @@
//! Returns a fairing that initializes the associated database connection //! Returns a fairing that initializes the associated database connection
//! pool. //! pool.
//! //!
//! * `fn get_one(&Cargo) -> Option<Self>` //! * `async fn get_one(&Cargo) -> Option<Self>`
//! //!
//! Retrieves a connection from the configured pool. Returns `Some` as long //! Retrieves a connection wrapper from the configured pool. Returns `Some`
//! as `Self::fairing()` has been attached and there is at least one //! as long as `Self::fairing()` has been attached.
//! connection in the pool.
//! //!
//! The fairing returned from the generated `fairing()` method _must_ be //! The fairing returned from the generated `fairing()` method _must_ be
//! attached for the request guard implementation to succeed. Putting the pieces //! attached for the request guard implementation to succeed. Putting the pieces
@ -280,8 +277,8 @@
//! //!
//! ## Handlers //! ## Handlers
//! //!
//! Finally, simply use your type as a request guard in a handler to retrieve a //! Finally, use your type as a request guard in a handler to retrieve a
//! connection to a given database: //! connection wrapper for the database:
//! //!
//! ```rust //! ```rust
//! # #[macro_use] extern crate rocket; //! # #[macro_use] extern crate rocket;
@ -300,8 +297,7 @@
//! # } //! # }
//! ``` //! ```
//! //!
//! The generated `Deref` implementation allows easy access to the inner //! A connection can be retrieved and used with the `run()` method:
//! connection type:
//! //!
//! ```rust //! ```rust
//! # #[macro_use] extern crate rocket; //! # #[macro_use] extern crate rocket;
@ -320,12 +316,43 @@
//! } //! }
//! //!
//! #[get("/")] //! #[get("/")]
//! fn my_handler(conn: MyDatabase) -> Data { //! async fn my_handler(mut conn: MyDatabase) -> Data {
//! load_from_db(&*conn) //! conn.run(|c| load_from_db(c)).await
//! } //! }
//! # } //! # }
//! ``` //! ```
//! //!
//! `run()` takes the connection by value. To make multiple calls to run,
//! obtain a second connection first with `clone()`:
//!
//! ```rust
//! # #[macro_use] extern crate rocket;
//! # #[macro_use] extern crate rocket_contrib;
//! #
//! # #[cfg(feature = "diesel_sqlite_pool")]
//! # mod test {
//! # use rocket_contrib::databases::diesel;
//! # type Data = ();
//! #[database("my_db")]
//! struct MyDatabase(diesel::SqliteConnection);
//!
//! fn load_from_db(conn: &diesel::SqliteConnection) -> Data {
//! // Do something with connection, return some data.
//! # ()
//! }
//!
//! #[get("/")]
//! async fn my_handler(mut conn: MyDatabase) -> Data {
//! let cloned = conn.clone().await.expect("");
//! cloned.run(|c| load_from_db(c)).await;
//!
//! // Do something else
//! conn.run(|c| load_from_db(c)).await;
//! }
//! # }
//! ```
//!
//!
//! # Database Support //! # Database Support
//! //!
//! Built-in support is provided for many popular databases and drivers. Support //! Built-in support is provided for many popular databases and drivers. Support
@ -380,18 +407,21 @@
pub extern crate r2d2; pub extern crate r2d2;
#[doc(hidden)]
pub use tokio::task::spawn_blocking;
#[cfg(any(feature = "diesel_sqlite_pool", #[cfg(any(feature = "diesel_sqlite_pool",
feature = "diesel_postgres_pool", feature = "diesel_postgres_pool",
feature = "diesel_mysql_pool"))] feature = "diesel_mysql_pool"))]
pub extern crate diesel; pub extern crate diesel;
use std::fmt::{self, Display, Formatter}; use std::fmt::{self, Display, Formatter};
use std::marker::{Send, Sized}; use std::marker::PhantomData;
use std::sync::Arc;
use rocket::config::{self, Value}; use rocket::config::{self, Value};
use rocket::fairing::{AdHoc, Fairing};
use rocket::http::Status;
use rocket::request::Outcome;
use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore};
use self::r2d2::ManageConnection; use self::r2d2::ManageConnection;
@ -688,7 +718,7 @@ pub trait Poolable: Send + Sized + 'static {
type Manager: ManageConnection<Connection=Self>; type Manager: ManageConnection<Connection=Self>;
/// The associated error type in the event that constructing the connection /// The associated error type in the event that constructing the connection
/// manager and/or the connection pool fails. /// manager and/or the connection pool fails.
type Error; type Error: std::fmt::Debug;
/// Creates an `r2d2` connection pool for `Manager::Connection`, returning /// Creates an `r2d2` connection pool for `Manager::Connection`, returning
/// the pool on success. /// the pool on success.
@ -780,6 +810,184 @@ impl Poolable for memcache::Client {
} }
} }
/// Unstable internal details of generated code for the #[database] attribute.
///
/// This type is implemented here instead of in generated code to ensure all
/// types are properly checked.
#[doc(hidden)]
pub struct ConnectionPool<K, C: Poolable> {
pool: r2d2::Pool<C::Manager>,
semaphore: Arc<Semaphore>,
_marker: PhantomData<fn() -> K>,
}
/// Unstable internal details of generated code for the #[database] attribute.
///
/// This type is implemented here instead of in generated code to ensure all
/// types are properly checked.
#[doc(hidden)]
pub struct Connection<K, C: Poolable> {
pool: ConnectionPool<K, C>,
connection: Option<r2d2::PooledConnection<C::Manager>>,
_permit: Option<OwnedSemaphorePermit>,
_marker: PhantomData<fn() -> K>,
}
// A wrapper around spawn_blocking that propagates panics to the calling code
async fn run_blocking<F, R>(job: F) -> R
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
match tokio::task::spawn_blocking(job).await {
Ok(ret) => ret,
Err(e) => match e.try_into_panic() {
Ok(panic) => std::panic::resume_unwind(panic),
Err(_) => unreachable!("spawn_blocking tasks are never canceled"),
}
}
}
impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
#[inline]
pub fn fairing(fairing_name: &'static str, config_name: &'static str) -> impl Fairing {
AdHoc::on_attach(fairing_name, move |mut rocket| async move {
let config = database_config(config_name, rocket.config().await);
let pool = config.map(|c| (c.pool_size, C::pool(c)));
match pool {
Ok((size, Ok(pool))) => {
let managed = ConnectionPool::<K, C> {
pool,
semaphore: Arc::new(Semaphore::new(size as usize)),
_marker: PhantomData,
};
Ok(rocket.manage(managed))
},
Err(config_error) => {
rocket::logger::error(
&format!("Database configuration failure: '{}'", config_name));
rocket::logger::error_(&config_error.to_string());
Err(rocket)
},
Ok((_, Err(pool_error))) => {
rocket::logger::error(
&format!("Failed to initialize pool for '{}'", config_name));
rocket::logger::error_(&format!("{:?}", pool_error));
Err(rocket)
},
}
})
}
async fn get(&self) -> Result<Connection<K, C>, ()> {
// TODO: timeout duration
let permit = match tokio::time::timeout(
self.pool.connection_timeout(),
self.semaphore.clone().acquire_owned()
).await {
Ok(p) => p,
Err(_) => {
error_!("Failed to get a database connection within the timeout.");
return Err(());
}
};
let pool = self.pool.clone();
// TODO: timeout duration
match run_blocking(move || pool.get_timeout(std::time::Duration::from_secs(0))).await {
Ok(c) => {
Ok(Connection {
pool: self.clone(),
connection: Some(c),
_permit: Some(permit),
_marker: PhantomData,
})
}
Err(e) => {
error_!("Failed to get a database connection: {}", e);
Err(())
}
}
}
#[inline]
pub async fn get_one(cargo: &rocket::Cargo) -> Option<Connection<K, C>> {
match cargo.state::<Self>() {
Some(pool) => pool.get().await.ok(),
None => {
error_!("Database fairing was not attached for {}", std::any::type_name::<K>());
None
}
}
}
}
impl<K, C: Poolable> Clone for ConnectionPool<K, C> {
fn clone(&self) -> Self {
Self {
pool: self.pool.clone(),
semaphore: self.semaphore.clone(),
_marker: PhantomData,
}
}
}
impl<K: 'static, C: Poolable> Connection<K, C> {
#[inline]
pub async fn run<F, R>(self, f: F) -> R
where F: FnOnce(&mut C) -> R + Send + 'static,
R: Send + 'static,
{
run_blocking(move || {
// 'self' contains a semaphore permit that should be held throughout
// this entire spawned task. Explicitly move it, so that a
// refactoring won't accidentally release the permit too early.
let mut this: Self = self;
let mut conn = this.connection.take()
.expect("internal invariant broken: self.connection is Some");
f(&mut conn)
}).await
}
#[inline]
pub async fn clone(&mut self) -> Result<Self, ()> {
self.pool.get().await
}
}
impl<K, C: Poolable> Drop for Connection<K, C> {
fn drop(&mut self) {
if let Some(conn) = self.connection.take() {
tokio::task::spawn_blocking(|| drop(conn));
}
}
}
#[rocket::async_trait]
impl<'a, 'r, K: 'static, C: Poolable> rocket::request::FromRequest<'a, 'r> for Connection<K, C> {
type Error = ();
#[inline]
async fn from_request(request: &'a rocket::request::Request<'r>) -> Outcome<Self, ()> {
match request.managed_state::<ConnectionPool<K, C>>() {
Some(inner) => {
match inner.get().await {
Ok(c) => Outcome::Success(c),
Err(()) => Outcome::Failure((Status::ServiceUnavailable, ())),
}
}
None => {
error_!("Database fairing was not attached for {}", std::any::type_name::<K>());
Outcome::Failure((Status::InternalServerError, ()))
}
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::BTreeMap; use std::collections::BTreeMap;

View File

@ -21,40 +21,31 @@ mod rusqlite_integration_test {
#[database("test_db")] #[database("test_db")]
struct SqliteDb(pub rusqlite::Connection); struct SqliteDb(pub rusqlite::Connection);
// Test to ensure that multiple databases of the same type can be used
#[database("test_db_2")]
struct SqliteDb2(pub rusqlite::Connection);
#[rocket::async_test] #[rocket::async_test]
async fn deref_mut_impl_present() { async fn test_db() {
let mut test_db: Map<String, Value> = Map::new(); let mut test_db: Map<String, Value> = Map::new();
let mut test_db_opts: Map<String, Value> = Map::new(); let mut test_db_opts: Map<String, Value> = Map::new();
test_db_opts.insert("url".into(), Value::String(":memory:".into())); test_db_opts.insert("url".into(), Value::String(":memory:".into()));
test_db.insert("test_db".into(), Value::Table(test_db_opts)); test_db.insert("test_db".into(), Value::Table(test_db_opts.clone()));
test_db.insert("test_db_2".into(), Value::Table(test_db_opts));
let config = Config::build(Environment::Development) let config = Config::build(Environment::Development)
.extra("databases", Value::Table(test_db)) .extra("databases", Value::Table(test_db))
.finalize() .finalize()
.unwrap(); .unwrap();
let mut rocket = rocket::custom(config).attach(SqliteDb::fairing()); let mut rocket = rocket::custom(config).attach(SqliteDb::fairing()).attach(SqliteDb2::fairing());
let mut conn = SqliteDb::get_one(rocket.inspect().await).expect("unable to get connection"); let conn = SqliteDb::get_one(rocket.inspect().await).await.expect("unable to get connection");
// Rusqlite's `transaction()` method takes `&mut self`; this tests the // Rusqlite's `transaction()` method takes `&mut self`; this tests that
// presence of a `DerefMut` trait on the generated connection type. // the &mut method can be called inside the closure passed to `run()`.
conn.run(|conn| {
let tx = conn.transaction().unwrap(); let tx = conn.transaction().unwrap();
let _: i32 = tx.query_row("SELECT 1", &[] as &[&dyn ToSql], |row| row.get(0)).expect("get row"); let _: i32 = tx.query_row("SELECT 1", &[] as &[&dyn ToSql], |row| row.get(0)).expect("get row");
tx.commit().expect("committed transaction"); tx.commit().expect("committed transaction");
} }).await;
#[rocket::async_test]
async fn deref_impl_present() {
let mut test_db: Map<String, Value> = Map::new();
let mut test_db_opts: Map<String, Value> = Map::new();
test_db_opts.insert("url".into(), Value::String(":memory:".into()));
test_db.insert("test_db".into(), Value::Table(test_db_opts));
let config = Config::build(Environment::Development)
.extra("databases", Value::Table(test_db))
.finalize()
.unwrap();
let mut rocket = rocket::custom(config).attach(SqliteDb::fairing());
let conn = SqliteDb::get_one(rocket.inspect().await).expect("unable to get connection");
let _: i32 = conn.query_row("SELECT 1", &[] as &[&dyn ToSql], |row| row.get(0)).expect("get row");
} }
} }

View File

@ -25,23 +25,26 @@ embed_migrations!();
pub struct DbConn(SqliteConnection); pub struct DbConn(SqliteConnection);
#[derive(Debug, serde::Serialize)] #[derive(Debug, serde::Serialize)]
struct Context<'a> { struct Context {
msg: Option<(&'a str, &'a str)>, msg: Option<(String, String)>,
tasks: Vec<Task> tasks: Vec<Task>
} }
impl<'a> Context<'a> { impl Context {
pub fn err(conn: &DbConn, msg: &'a str) -> Context<'a> { pub async fn err(conn: DbConn, msg: String) -> Context {
Context { msg: Some(("error", msg)), tasks: Task::all(conn).unwrap_or_default() } Context {
msg: Some(("error".to_string(), msg)),
tasks: Task::all(conn).await.unwrap_or_default()
}
} }
pub fn raw(conn: &DbConn, msg: Option<(&'a str, &'a str)>) -> Context<'a> { pub async fn raw(conn: DbConn, msg: Option<(String, String)>) -> Context {
match Task::all(conn) { match Task::all(conn).await {
Ok(tasks) => Context { msg, tasks }, Ok(tasks) => Context { msg, tasks },
Err(e) => { Err(e) => {
error_!("DB Task::all() error: {}", e); error_!("DB Task::all() error: {}", e);
Context { Context {
msg: Some(("error", "Couldn't access the task database.")), msg: Some(("error".to_string(), "Couldn't access the task database.".to_string())),
tasks: vec![] tasks: vec![]
} }
} }
@ -50,11 +53,11 @@ impl<'a> Context<'a> {
} }
#[post("/", data = "<todo_form>")] #[post("/", data = "<todo_form>")]
fn new(todo_form: Form<Todo>, conn: DbConn) -> Flash<Redirect> { async fn new(todo_form: Form<Todo>, conn: DbConn) -> Flash<Redirect> {
let todo = todo_form.into_inner(); let todo = todo_form.into_inner();
if todo.description.is_empty() { if todo.description.is_empty() {
Flash::error(Redirect::to("/"), "Description cannot be empty.") Flash::error(Redirect::to("/"), "Description cannot be empty.")
} else if let Err(e) = Task::insert(todo, &conn) { } else if let Err(e) = Task::insert(todo, conn).await {
error_!("DB insertion error: {}", e); error_!("DB insertion error: {}", e);
Flash::error(Redirect::to("/"), "Todo could not be inserted due an internal error.") Flash::error(Redirect::to("/"), "Todo could not be inserted due an internal error.")
} else { } else {
@ -63,42 +66,48 @@ fn new(todo_form: Form<Todo>, conn: DbConn) -> Flash<Redirect> {
} }
#[put("/<id>")] #[put("/<id>")]
fn toggle(id: i32, conn: DbConn) -> Result<Redirect, Template> { async fn toggle(id: i32, mut conn: DbConn) -> Result<Redirect, Template> {
Task::toggle_with_id(id, &conn) // TODO
.map(|_| Redirect::to("/")) let conn2 = conn.clone().await.unwrap();
.map_err(|e| { match Task::toggle_with_id(id, conn).await {
Ok(_) => Ok(Redirect::to("/")),
Err(e) => {
error_!("DB toggle({}) error: {}", id, e); error_!("DB toggle({}) error: {}", id, e);
Template::render("index", Context::err(&conn, "Failed to toggle task.")) Err(Template::render("index", Context::err(conn2, "Failed to toggle task.".to_string()).await))
}) }
}
} }
#[delete("/<id>")] #[delete("/<id>")]
fn delete(id: i32, conn: DbConn) -> Result<Flash<Redirect>, Template> { async fn delete(id: i32, mut conn: DbConn) -> Result<Flash<Redirect>, Template> {
Task::delete_with_id(id, &conn) // TODO
.map(|_| Flash::success(Redirect::to("/"), "Todo was deleted.")) let conn2 = conn.clone().await.unwrap();
.map_err(|e| { match Task::delete_with_id(id, conn).await {
Ok(_) => Ok(Flash::success(Redirect::to("/"), "Todo was deleted.")),
Err(e) => {
error_!("DB deletion({}) error: {}", id, e); error_!("DB deletion({}) error: {}", id, e);
Template::render("index", Context::err(&conn, "Failed to delete task.")) Err(Template::render("index", Context::err(conn2, "Failed to delete task.".to_string()).await))
}) }
}
} }
#[get("/")] #[get("/")]
fn index(msg: Option<FlashMessage<'_, '_>>, conn: DbConn) -> Template { async fn index(msg: Option<FlashMessage<'_, '_>>, conn: DbConn) -> Template {
Template::render("index", match msg { let msg = msg.map(|m| (m.name().to_string(), m.msg().to_string()));
Some(ref msg) => Context::raw(&conn, Some((msg.name(), msg.msg()))), Template::render("index", Context::raw(conn, msg).await)
None => Context::raw(&conn, None),
})
} }
async fn run_db_migrations(mut rocket: Rocket) -> Result<Rocket, Rocket> { async fn run_db_migrations(mut rocket: Rocket) -> Result<Rocket, Rocket> {
let conn = DbConn::get_one(rocket.inspect().await).expect("database connection"); let conn = DbConn::get_one(rocket.inspect().await).await.expect("database connection");
match embedded_migrations::run(&*conn) { conn.run(|c| {
match embedded_migrations::run(c) {
Ok(()) => Ok(rocket), Ok(()) => Ok(rocket),
Err(e) => { Err(e) => {
error!("Failed to run database migrations: {:?}", e); error!("Failed to run database migrations: {:?}", e);
Err(rocket) Err(rocket)
} }
} }
}).await
} }
#[launch] #[launch]

View File

@ -13,6 +13,8 @@ mod schema {
use self::schema::tasks; use self::schema::tasks;
use self::schema::tasks::dsl::{tasks as all_tasks, completed as task_completed}; use self::schema::tasks::dsl::{tasks as all_tasks, completed as task_completed};
use crate::DbConn;
#[table_name="tasks"] #[table_name="tasks"]
#[derive(serde::Serialize, Queryable, Insertable, Debug, Clone)] #[derive(serde::Serialize, Queryable, Insertable, Debug, Clone)]
pub struct Task { pub struct Task {
@ -27,32 +29,38 @@ pub struct Todo {
} }
impl Task { impl Task {
pub fn all(conn: &SqliteConnection) -> QueryResult<Vec<Task>> { pub async fn all(conn: DbConn) -> QueryResult<Vec<Task>> {
all_tasks.order(tasks::id.desc()).load::<Task>(conn) conn.run(|c| {
all_tasks.order(tasks::id.desc()).load::<Task>(c)
}).await
} }
/// Returns the number of affected rows: 1. /// Returns the number of affected rows: 1.
pub fn insert(todo: Todo, conn: &SqliteConnection) -> QueryResult<usize> { pub async fn insert(todo: Todo, conn: DbConn) -> QueryResult<usize> {
conn.run(|c| {
let t = Task { id: None, description: todo.description, completed: false }; let t = Task { id: None, description: todo.description, completed: false };
diesel::insert_into(tasks::table).values(&t).execute(conn) diesel::insert_into(tasks::table).values(&t).execute(c)
}).await
} }
/// Returns the number of affected rows: 1. /// Returns the number of affected rows: 1.
pub fn toggle_with_id(id: i32, conn: &SqliteConnection) -> QueryResult<usize> { pub async fn toggle_with_id(id: i32, conn: DbConn) -> QueryResult<usize> {
let task = all_tasks.find(id).get_result::<Task>(conn)?; conn.run(move |c| {
let task = all_tasks.find(id).get_result::<Task>(c)?;
let new_status = !task.completed; let new_status = !task.completed;
let updated_task = diesel::update(all_tasks.find(id)); let updated_task = diesel::update(all_tasks.find(id));
updated_task.set(task_completed.eq(new_status)).execute(conn) updated_task.set(task_completed.eq(new_status)).execute(c)
}).await
} }
/// Returns the number of affected rows: 1. /// Returns the number of affected rows: 1.
pub fn delete_with_id(id: i32, conn: &SqliteConnection) -> QueryResult<usize> { pub async fn delete_with_id(id: i32, conn: DbConn) -> QueryResult<usize> {
diesel::delete(all_tasks.find(id)).execute(conn) conn.run(move |c| diesel::delete(all_tasks.find(id)).execute(c)).await
} }
/// Returns the number of affected rows. /// Returns the number of affected rows.
#[cfg(test)] #[cfg(test)]
pub fn delete_all(conn: &SqliteConnection) -> QueryResult<usize> { pub async fn delete_all(conn: DbConn) -> QueryResult<usize> {
diesel::delete(all_tasks).execute(conn) conn.run(|c| diesel::delete(all_tasks).execute(c)).await
} }
} }

View File

@ -3,7 +3,7 @@ use super::task::Task;
use parking_lot::Mutex; use parking_lot::Mutex;
use rand::{Rng, thread_rng, distributions::Alphanumeric}; use rand::{Rng, thread_rng, distributions::Alphanumeric};
use rocket::local::blocking::Client; use rocket::local::asynchronous::Client;
use rocket::http::{Status, ContentType}; use rocket::http::{Status, ContentType};
// We use a lock to synchronize between tests so DB operations don't collide. // We use a lock to synchronize between tests so DB operations don't collide.
@ -15,30 +15,34 @@ macro_rules! run_test {
(|$client:ident, $conn:ident| $block:expr) => ({ (|$client:ident, $conn:ident| $block:expr) => ({
let _lock = DB_LOCK.lock(); let _lock = DB_LOCK.lock();
rocket::async_test(async move {
let rocket = super::rocket(); let rocket = super::rocket();
let $client = Client::new(rocket).expect("Rocket client"); let $client = Client::new(rocket).await.expect("Rocket client");
let db = super::DbConn::get_one($client.cargo()); let db = super::DbConn::get_one($client.cargo()).await;
let $conn = db.expect("failed to get database connection for testing"); let mut $conn = db.expect("failed to get database connection for testing");
Task::delete_all(&$conn).expect("failed to delete all tasks for testing"); let delete_conn = $conn.clone().await.expect("failed to get a second database connection for testing");
Task::delete_all(delete_conn).await.expect("failed to delete all tasks for testing");
$block $block
}) })
})
} }
#[test] #[test]
fn test_insertion_deletion() { fn test_insertion_deletion() {
run_test!(|client, conn| { run_test!(|client, conn| {
// Get the tasks before making changes. // Get the tasks before making changes.
let init_tasks = Task::all(&conn).unwrap(); let init_tasks = Task::all(conn.clone().await.unwrap()).await.unwrap();
// Issue a request to insert a new task. // Issue a request to insert a new task.
client.post("/todo") client.post("/todo")
.header(ContentType::Form) .header(ContentType::Form)
.body("description=My+first+task") .body("description=My+first+task")
.dispatch(); .dispatch()
.await;
// Ensure we have one more task in the database. // Ensure we have one more task in the database.
let new_tasks = Task::all(&conn).unwrap(); let new_tasks = Task::all(conn.clone().await.unwrap()).await.unwrap();
assert_eq!(new_tasks.len(), init_tasks.len() + 1); assert_eq!(new_tasks.len(), init_tasks.len() + 1);
// Ensure the task is what we expect. // Ensure the task is what we expect.
@ -47,10 +51,10 @@ fn test_insertion_deletion() {
// Issue a request to delete the task. // Issue a request to delete the task.
let id = new_tasks[0].id.unwrap(); let id = new_tasks[0].id.unwrap();
client.delete(format!("/todo/{}", id)).dispatch(); client.delete(format!("/todo/{}", id)).dispatch().await;
// Ensure it's gone. // Ensure it's gone.
let final_tasks = Task::all(&conn).unwrap(); let final_tasks = Task::all(conn).await.unwrap();
assert_eq!(final_tasks.len(), init_tasks.len()); assert_eq!(final_tasks.len(), init_tasks.len());
if final_tasks.len() > 0 { if final_tasks.len() > 0 {
assert_ne!(final_tasks[0].description, "My first task"); assert_ne!(final_tasks[0].description, "My first task");
@ -65,18 +69,19 @@ fn test_toggle() {
client.post("/todo") client.post("/todo")
.header(ContentType::Form) .header(ContentType::Form)
.body("description=test_for_completion") .body("description=test_for_completion")
.dispatch(); .dispatch()
.await;
let task = Task::all(&conn).unwrap()[0].clone(); let task = Task::all(conn.clone().await.unwrap()).await.unwrap()[0].clone();
assert_eq!(task.completed, false); assert_eq!(task.completed, false);
// Issue a request to toggle the task; ensure it is completed. // Issue a request to toggle the task; ensure it is completed.
client.put(format!("/todo/{}", task.id.unwrap())).dispatch(); client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await;
assert_eq!(Task::all(&conn).unwrap()[0].completed, true); assert_eq!(Task::all(conn.clone().await.unwrap()).await.unwrap()[0].completed, true);
// Issue a request to toggle the task; ensure it's not completed again. // Issue a request to toggle the task; ensure it's not completed again.
client.put(format!("/todo/{}", task.id.unwrap())).dispatch(); client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await;
assert_eq!(Task::all(&conn).unwrap()[0].completed, false); assert_eq!(Task::all(conn).await.unwrap()[0].completed, false);
}) })
} }
@ -86,7 +91,7 @@ fn test_many_insertions() {
run_test!(|client, conn| { run_test!(|client, conn| {
// Get the number of tasks initially. // Get the number of tasks initially.
let init_num = Task::all(&conn).unwrap().len(); let init_num = Task::all(conn.clone().await.unwrap()).await.unwrap().len();
let mut descs = Vec::new(); let mut descs = Vec::new();
for i in 0..ITER { for i in 0..ITER {
@ -95,13 +100,14 @@ fn test_many_insertions() {
client.post("/todo") client.post("/todo")
.header(ContentType::Form) .header(ContentType::Form)
.body(format!("description={}", desc)) .body(format!("description={}", desc))
.dispatch(); .dispatch()
.await;
// Record the description we choose for this iteration. // Record the description we choose for this iteration.
descs.insert(0, desc); descs.insert(0, desc);
// Ensure the task was inserted properly and all other tasks remain. // Ensure the task was inserted properly and all other tasks remain.
let tasks = Task::all(&conn).unwrap(); let tasks = Task::all(conn.clone().await.unwrap()).await.unwrap();
assert_eq!(tasks.len(), init_num + i + 1); assert_eq!(tasks.len(), init_num + i + 1);
for j in 0..i { for j in 0..i {
@ -117,7 +123,8 @@ fn test_bad_form_submissions() {
// Submit an empty form. We should get a 422 but no flash error. // Submit an empty form. We should get a 422 but no flash error.
let res = client.post("/todo") let res = client.post("/todo")
.header(ContentType::Form) .header(ContentType::Form)
.dispatch(); .dispatch()
.await;
let mut cookies = res.headers().get("Set-Cookie"); let mut cookies = res.headers().get("Set-Cookie");
assert_eq!(res.status(), Status::UnprocessableEntity); assert_eq!(res.status(), Status::UnprocessableEntity);
@ -128,7 +135,8 @@ fn test_bad_form_submissions() {
let res = client.post("/todo") let res = client.post("/todo")
.header(ContentType::Form) .header(ContentType::Form)
.body("description=") .body("description=")
.dispatch(); .dispatch()
.await;
let mut cookies = res.headers().get("Set-Cookie"); let mut cookies = res.headers().get("Set-Cookie");
assert!(cookies.any(|value| value.contains("error"))); assert!(cookies.any(|value| value.contains("error")));
@ -137,7 +145,8 @@ fn test_bad_form_submissions() {
let res = client.post("/todo") let res = client.post("/todo")
.header(ContentType::Form) .header(ContentType::Form)
.body("evil=smile") .body("evil=smile")
.dispatch(); .dispatch()
.await;
let mut cookies = res.headers().get("Set-Cookie"); let mut cookies = res.headers().get("Set-Cookie");
assert_eq!(res.status(), Status::UnprocessableEntity); assert_eq!(res.status(), Status::UnprocessableEntity);

View File

@ -205,7 +205,6 @@ request-local state to implement request timing.
[`FromRequest` request-local state]: @api/rocket/request/trait.FromRequest.html#request-local-state [`FromRequest` request-local state]: @api/rocket/request/trait.FromRequest.html#request-local-state
[`Fairing`]: @api/rocket/fairing/trait.Fairing.html#request-local-state [`Fairing`]: @api/rocket/fairing/trait.Fairing.html#request-local-state
<!-- TODO.async: rewrite? -->
## Databases ## Databases
Rocket includes built-in, ORM-agnostic support for databases. In particular, Rocket includes built-in, ORM-agnostic support for databases. In particular,
@ -222,7 +221,7 @@ three simple steps:
1. Configure the databases in `Rocket.toml`. 1. Configure the databases in `Rocket.toml`.
2. Associate a request guard type and fairing with each database. 2. Associate a request guard type and fairing with each database.
3. Use the request guard to retrieve a connection in a handler. 3. Use the request guard to retrieve and use a connection in a handler.
Presently, Rocket provides built-in support for the following databases: Presently, Rocket provides built-in support for the following databases:
@ -301,7 +300,7 @@ fn rocket() -> rocket::Rocket {
``` ```
That's it! Whenever a connection to the database is needed, use your type as a That's it! Whenever a connection to the database is needed, use your type as a
request guard: request guard. The database can be accessed by calling the `run` method:
```rust ```rust
# #[macro_use] extern crate rocket; # #[macro_use] extern crate rocket;
@ -315,9 +314,9 @@ request guard:
# type Logs = (); # type Logs = ();
#[get("/logs/<id>")] #[get("/logs/<id>")]
fn get_logs(conn: LogsDbConn, id: usize) -> Logs { async fn get_logs(conn: LogsDbConn, id: usize) -> Logs {
# /* # /*
logs::filter(id.eq(log_id)).load(&*conn) conn.run(|c| logs::filter(id.eq(log_id)).load(c)).await
# */ # */
} }
``` ```
@ -329,6 +328,15 @@ fn get_logs(conn: LogsDbConn, id: usize) -> Logs {
syntax. Rocket does not provide an ORM. It is up to you to decide how to model syntax. Rocket does not provide an ORM. It is up to you to decide how to model
your application's data. your application's data.
! note
The database engines supported by `#[database]` are *synchronous*. Normally,
using such a database would block the thread of execution. To prevent this,
the `run()` function automatically uses a thread pool so that database access
does not interfere with other in-flight requests. See [Cooperative
Multitasking](../overview/#cooperative-multitasking) for more information on
why this is necessary.
If your application uses features of a database engine that are not available If your application uses features of a database engine that are not available
by default, for example support for `chrono` or `uuid`, you may enable those by default, for example support for `chrono` or `uuid`, you may enable those
features by adding them in `Cargo.toml` like so: features by adding them in `Cargo.toml` like so: