mirror of https://github.com/rwf2/Rocket.git
Borrow 'self' in 'Connection::run()'.
This simulates the pre-async behavior of serialization attempts to use a connection by using an `async` Mutex.
This commit is contained in:
parent
bc8c5b9ee2
commit
dee11966b6
|
@ -97,18 +97,13 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
|
||||||
/// closure will be passed an `&mut r2d2::PooledConnection`.
|
/// closure will be passed an `&mut r2d2::PooledConnection`.
|
||||||
/// `.await`ing the return value of this function yields the value
|
/// `.await`ing the return value of this function yields the value
|
||||||
/// returned by the closure.
|
/// returned by the closure.
|
||||||
pub async fn run<F, R>(self, f: F) -> R
|
pub async fn run<F, R>(&self, f: F) -> R
|
||||||
where
|
where
|
||||||
F: FnOnce(&mut #conn_type) -> R + Send + 'static,
|
F: FnOnce(&mut #conn_type) -> R + Send + 'static,
|
||||||
R: Send + 'static,
|
R: Send + 'static,
|
||||||
{
|
{
|
||||||
self.0.run(f).await
|
self.0.run(f).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Asynchronously acquires another connection from the connection pool.
|
|
||||||
pub async fn clone(&mut self) -> ::std::result::Result<Self, ()> {
|
|
||||||
self.0.clone().await.map(Self)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[::rocket::async_trait]
|
#[::rocket::async_trait]
|
||||||
|
|
|
@ -1,21 +1,21 @@
|
||||||
error[E0277]: the trait bound `Unknown: rocket_contrib::databases::Poolable` is not satisfied
|
error[E0277]: the trait bound `Unknown: Poolable` is not satisfied
|
||||||
--> $DIR/database-types.rs:7:10
|
--> $DIR/database-types.rs:7:10
|
||||||
|
|
|
|
||||||
7 | struct A(Unknown);
|
7 | struct A(Unknown);
|
||||||
| ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown`
|
| ^^^^^^^ the trait `Poolable` is not implemented for `Unknown`
|
||||||
|
|
|
|
||||||
::: $WORKSPACE/contrib/lib/src/databases.rs:832:29
|
::: $WORKSPACE/contrib/lib/src/databases.rs
|
||||||
|
|
|
|
||||||
832 | pub struct Connection<K, C: Poolable> {
|
| pub struct Connection<K, C: Poolable> {
|
||||||
| -------- required by this bound in `rocket_contrib::databases::Connection`
|
| -------- required by this bound in `rocket_contrib::databases::Connection`
|
||||||
|
|
||||||
error[E0277]: the trait bound `std::vec::Vec<i32>: rocket_contrib::databases::Poolable` is not satisfied
|
error[E0277]: the trait bound `Vec<i32>: Poolable` is not satisfied
|
||||||
--> $DIR/database-types.rs:10:10
|
--> $DIR/database-types.rs:10:10
|
||||||
|
|
|
|
||||||
10 | struct B(Vec<i32>);
|
10 | struct B(Vec<i32>);
|
||||||
| ^^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec<i32>`
|
| ^^^^^^^^ the trait `Poolable` is not implemented for `Vec<i32>`
|
||||||
|
|
|
|
||||||
::: $WORKSPACE/contrib/lib/src/databases.rs:832:29
|
::: $WORKSPACE/contrib/lib/src/databases.rs
|
||||||
|
|
|
|
||||||
832 | pub struct Connection<K, C: Poolable> {
|
| pub struct Connection<K, C: Poolable> {
|
||||||
| -------- required by this bound in `rocket_contrib::databases::Connection`
|
| -------- required by this bound in `rocket_contrib::databases::Connection`
|
||||||
|
|
|
@ -191,6 +191,23 @@
|
||||||
//! database. This corresponds to the database name set as the database's
|
//! database. This corresponds to the database name set as the database's
|
||||||
//! configuration key.
|
//! configuration key.
|
||||||
//!
|
//!
|
||||||
|
//! The macro generates a [`FromRequest`] implementation for the decorated type,
|
||||||
|
//! allowing the type to be used as a request guard. This implementation
|
||||||
|
//! retrieves a connection from the database pool or fails with a
|
||||||
|
//! `Status::ServiceUnavailable` if connecting to the database times out.
|
||||||
|
//!
|
||||||
|
//! The macro will also generate two inherent methods on the decorated type:
|
||||||
|
//!
|
||||||
|
//! * `fn fairing() -> impl Fairing`
|
||||||
|
//!
|
||||||
|
//! Returns a fairing that initializes the associated database connection
|
||||||
|
//! pool.
|
||||||
|
//!
|
||||||
|
//! * `async fn get_one(&Cargo) -> Option<Self>`
|
||||||
|
//!
|
||||||
|
//! Retrieves a connection wrapper from the configured pool. Returns `Some`
|
||||||
|
//! as long as `Self::fairing()` has been attached.
|
||||||
|
//!
|
||||||
//! The attribute can only be applied to unit-like structs with one type. The
|
//! The attribute can only be applied to unit-like structs with one type. The
|
||||||
//! internal type of the structure must implement [`Poolable`].
|
//! internal type of the structure must implement [`Poolable`].
|
||||||
//!
|
//!
|
||||||
|
@ -221,23 +238,6 @@
|
||||||
//! # }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! The macro generates a [`FromRequest`] implementation for the decorated type,
|
|
||||||
//! allowing the type to be used as a request guard. This implementation
|
|
||||||
//! retrieves a connection from the database pool or fails with a
|
|
||||||
//! `Status::ServiceUnavailable` if connecting to the database times out.
|
|
||||||
//!
|
|
||||||
//! The macro will also generate two inherent methods on the decorated type:
|
|
||||||
//!
|
|
||||||
//! * `fn fairing() -> impl Fairing`
|
|
||||||
//!
|
|
||||||
//! Returns a fairing that initializes the associated database connection
|
|
||||||
//! pool.
|
|
||||||
//!
|
|
||||||
//! * `async fn get_one(&Cargo) -> Option<Self>`
|
|
||||||
//!
|
|
||||||
//! Retrieves a connection wrapper from the configured pool. Returns `Some`
|
|
||||||
//! as long as `Self::fairing()` has been attached.
|
|
||||||
//!
|
|
||||||
//! The fairing returned from the generated `fairing()` method _must_ be
|
//! The fairing returned from the generated `fairing()` method _must_ be
|
||||||
//! attached for the request guard implementation to succeed. Putting the pieces
|
//! attached for the request guard implementation to succeed. Putting the pieces
|
||||||
//! together, a use of the `#[database]` attribute looks as follows:
|
//! together, a use of the `#[database]` attribute looks as follows:
|
||||||
|
@ -322,37 +322,6 @@
|
||||||
//! # }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! `run()` takes the connection by value. To make multiple calls to run,
|
|
||||||
//! obtain a second connection first with `clone()`:
|
|
||||||
//!
|
|
||||||
//! ```rust
|
|
||||||
//! # #[macro_use] extern crate rocket;
|
|
||||||
//! # #[macro_use] extern crate rocket_contrib;
|
|
||||||
//! #
|
|
||||||
//! # #[cfg(feature = "diesel_sqlite_pool")]
|
|
||||||
//! # mod test {
|
|
||||||
//! # use rocket_contrib::databases::diesel;
|
|
||||||
//! # type Data = ();
|
|
||||||
//! #[database("my_db")]
|
|
||||||
//! struct MyDatabase(diesel::SqliteConnection);
|
|
||||||
//!
|
|
||||||
//! fn load_from_db(conn: &diesel::SqliteConnection) -> Data {
|
|
||||||
//! // Do something with connection, return some data.
|
|
||||||
//! # ()
|
|
||||||
//! }
|
|
||||||
//!
|
|
||||||
//! #[get("/")]
|
|
||||||
//! async fn my_handler(mut conn: MyDatabase) -> Data {
|
|
||||||
//! let cloned = conn.clone().await.expect("");
|
|
||||||
//! cloned.run(|c| load_from_db(c)).await;
|
|
||||||
//!
|
|
||||||
//! // Do something else
|
|
||||||
//! conn.run(|c| load_from_db(c)).await;
|
|
||||||
//! }
|
|
||||||
//! # }
|
|
||||||
//! ```
|
|
||||||
//!
|
|
||||||
//!
|
|
||||||
//! # Database Support
|
//! # Database Support
|
||||||
//!
|
//!
|
||||||
//! Built-in support is provided for many popular databases and drivers. Support
|
//! Built-in support is provided for many popular databases and drivers. Support
|
||||||
|
@ -418,10 +387,11 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use rocket::config::{self, Value};
|
use rocket::config::{self, Value};
|
||||||
use rocket::fairing::{AdHoc, Fairing};
|
use rocket::fairing::{AdHoc, Fairing};
|
||||||
|
use rocket::request::{Request, Outcome, FromRequest};
|
||||||
|
use rocket::outcome::IntoOutcome;
|
||||||
use rocket::http::Status;
|
use rocket::http::Status;
|
||||||
use rocket::request::Outcome;
|
|
||||||
|
|
||||||
use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore, Mutex};
|
||||||
|
|
||||||
use self::r2d2::ManageConnection;
|
use self::r2d2::ManageConnection;
|
||||||
|
|
||||||
|
@ -827,29 +797,25 @@ pub struct ConnectionPool<K, C: Poolable> {
|
||||||
/// types are properly checked.
|
/// types are properly checked.
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub struct Connection<K, C: Poolable> {
|
pub struct Connection<K, C: Poolable> {
|
||||||
pool: ConnectionPool<K, C>,
|
connection: Arc<Mutex<Option<r2d2::PooledConnection<C::Manager>>>>,
|
||||||
connection: Option<r2d2::PooledConnection<C::Manager>>,
|
permit: Option<OwnedSemaphorePermit>,
|
||||||
_permit: Option<OwnedSemaphorePermit>,
|
|
||||||
_marker: PhantomData<fn() -> K>,
|
_marker: PhantomData<fn() -> K>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// A wrapper around spawn_blocking that propagates panics to the calling code
|
// A wrapper around spawn_blocking that propagates panics to the calling code.
|
||||||
async fn run_blocking<F, R>(job: F) -> R
|
async fn run_blocking<F, R>(job: F) -> R
|
||||||
where
|
where F: FnOnce() -> R + Send + 'static, R: Send + 'static,
|
||||||
F: FnOnce() -> R + Send + 'static,
|
|
||||||
R: Send + 'static,
|
|
||||||
{
|
{
|
||||||
match tokio::task::spawn_blocking(job).await {
|
match tokio::task::spawn_blocking(job).await {
|
||||||
Ok(ret) => ret,
|
Ok(ret) => ret,
|
||||||
Err(e) => match e.try_into_panic() {
|
Err(e) => match e.try_into_panic() {
|
||||||
Ok(panic) => std::panic::resume_unwind(panic),
|
Ok(panic) => std::panic::resume_unwind(panic),
|
||||||
Err(_) => unreachable!("spawn_blocking tasks are never canceled"),
|
Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
|
impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
|
||||||
#[inline]
|
|
||||||
pub fn fairing(fairing_name: &'static str, config_name: &'static str) -> impl Fairing {
|
pub fn fairing(fairing_name: &'static str, config_name: &'static str) -> impl Fairing {
|
||||||
AdHoc::on_attach(fairing_name, move |mut rocket| async move {
|
AdHoc::on_attach(fairing_name, move |mut rocket| async move {
|
||||||
let config = database_config(config_name, rocket.config().await);
|
let config = database_config(config_name, rocket.config().await);
|
||||||
|
@ -881,9 +847,9 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get(&self) -> Result<Connection<K, C>, ()> {
|
async fn get(&self) -> Result<Connection<K, C>, ()> {
|
||||||
// TODO: timeout duration
|
// TODO: Make timeout configurable.
|
||||||
let permit = match tokio::time::timeout(
|
let permit = match tokio::time::timeout(
|
||||||
self.pool.connection_timeout(),
|
std::time::Duration::from_secs(5),
|
||||||
self.semaphore.clone().acquire_owned()
|
self.semaphore.clone().acquire_owned()
|
||||||
).await {
|
).await {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
|
@ -893,18 +859,14 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO: Make timeout configurable.
|
||||||
let pool = self.pool.clone();
|
let pool = self.pool.clone();
|
||||||
|
match run_blocking(move || pool.get_timeout(std::time::Duration::from_secs(5))).await {
|
||||||
// TODO: timeout duration
|
Ok(c) => Ok(Connection {
|
||||||
match run_blocking(move || pool.get_timeout(std::time::Duration::from_secs(0))).await {
|
connection: Arc::new(Mutex::new(Some(c))),
|
||||||
Ok(c) => {
|
permit: Some(permit),
|
||||||
Ok(Connection {
|
_marker: PhantomData,
|
||||||
pool: self.clone(),
|
}),
|
||||||
connection: Some(c),
|
|
||||||
_permit: Some(permit),
|
|
||||||
_marker: PhantomData,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error_!("Failed to get a database connection: {}", e);
|
error_!("Failed to get a database connection: {}", e);
|
||||||
Err(())
|
Err(())
|
||||||
|
@ -924,64 +886,49 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<K, C: Poolable> Clone for ConnectionPool<K, C> {
|
|
||||||
fn clone(&self) -> Self {
|
|
||||||
Self {
|
|
||||||
pool: self.pool.clone(),
|
|
||||||
semaphore: self.semaphore.clone(),
|
|
||||||
_marker: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<K: 'static, C: Poolable> Connection<K, C> {
|
impl<K: 'static, C: Poolable> Connection<K, C> {
|
||||||
#[inline]
|
#[inline]
|
||||||
pub async fn run<F, R>(self, f: F) -> R
|
pub async fn run<F, R>(&self, f: F) -> R
|
||||||
where F: FnOnce(&mut C) -> R + Send + 'static,
|
where F: FnOnce(&mut C) -> R + Send + 'static,
|
||||||
R: Send + 'static,
|
R: Send + 'static,
|
||||||
{
|
{
|
||||||
|
let mut connection = self.connection.clone().lock_owned().await;
|
||||||
run_blocking(move || {
|
run_blocking(move || {
|
||||||
// 'self' contains a semaphore permit that should be held throughout
|
let conn = connection.as_mut()
|
||||||
// this entire spawned task. Explicitly move it, so that a
|
|
||||||
// refactoring won't accidentally release the permit too early.
|
|
||||||
let mut this: Self = self;
|
|
||||||
|
|
||||||
let mut conn = this.connection.take()
|
|
||||||
.expect("internal invariant broken: self.connection is Some");
|
.expect("internal invariant broken: self.connection is Some");
|
||||||
|
f(conn)
|
||||||
f(&mut conn)
|
|
||||||
}).await
|
}).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub async fn clone(&mut self) -> Result<Self, ()> {
|
|
||||||
self.pool.get().await
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<K, C: Poolable> Drop for Connection<K, C> {
|
impl<K, C: Poolable> Drop for Connection<K, C> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
if let Some(conn) = self.connection.take() {
|
let connection = self.connection.clone();
|
||||||
tokio::task::spawn_blocking(|| drop(conn));
|
let permit = self.permit.take();
|
||||||
}
|
tokio::spawn(async move {
|
||||||
|
let mut connection = connection.lock_owned().await;
|
||||||
|
tokio::task::spawn_blocking(move || {
|
||||||
|
if let Some(conn) = connection.take() {
|
||||||
|
drop(conn);
|
||||||
|
}
|
||||||
|
// NB: Explicitly dropping the permit here so that it's only
|
||||||
|
// released after the connection is.
|
||||||
|
drop(permit);
|
||||||
|
})
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[rocket::async_trait]
|
#[rocket::async_trait]
|
||||||
impl<'a, 'r, K: 'static, C: Poolable> rocket::request::FromRequest<'a, 'r> for Connection<K, C> {
|
impl<'a, 'r, K: 'static, C: Poolable> FromRequest<'a, 'r> for Connection<K, C> {
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
async fn from_request(request: &'a rocket::request::Request<'r>) -> Outcome<Self, ()> {
|
async fn from_request(request: &'a Request<'r>) -> Outcome<Self, ()> {
|
||||||
match request.managed_state::<ConnectionPool<K, C>>() {
|
match request.managed_state::<ConnectionPool<K, C>>() {
|
||||||
Some(inner) => {
|
Some(c) => c.get().await.into_outcome(Status::ServiceUnavailable),
|
||||||
match inner.get().await {
|
|
||||||
Ok(c) => Outcome::Success(c),
|
|
||||||
Err(()) => Outcome::Failure((Status::ServiceUnavailable, ())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => {
|
None => {
|
||||||
error_!("Database fairing was not attached for {}", std::any::type_name::<K>());
|
error_!("Missing database fairing for `{}`", std::any::type_name::<K>());
|
||||||
Outcome::Failure((Status::InternalServerError, ()))
|
Outcome::Failure((Status::InternalServerError, ()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,8 @@
|
||||||
mod task;
|
mod task;
|
||||||
#[cfg(test)] mod tests;
|
#[cfg(test)] mod tests;
|
||||||
|
|
||||||
|
use std::fmt::Display;
|
||||||
|
|
||||||
use rocket::Rocket;
|
use rocket::Rocket;
|
||||||
use rocket::fairing::AdHoc;
|
use rocket::fairing::AdHoc;
|
||||||
use rocket::request::{Form, FlashMessage};
|
use rocket::request::{Form, FlashMessage};
|
||||||
|
@ -31,20 +33,20 @@ struct Context {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Context {
|
impl Context {
|
||||||
pub async fn err(conn: DbConn, msg: String) -> Context {
|
pub async fn err<M: Display>(conn: &DbConn, msg: M) -> Context {
|
||||||
Context {
|
Context {
|
||||||
msg: Some(("error".to_string(), msg)),
|
msg: Some(("error".into(), msg.to_string())),
|
||||||
tasks: Task::all(conn).await.unwrap_or_default()
|
tasks: Task::all(conn).await.unwrap_or_default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn raw(conn: DbConn, msg: Option<(String, String)>) -> Context {
|
pub async fn raw(conn: &DbConn, msg: Option<(String, String)>) -> Context {
|
||||||
match Task::all(conn).await {
|
match Task::all(conn).await {
|
||||||
Ok(tasks) => Context { msg, tasks },
|
Ok(tasks) => Context { msg, tasks },
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error_!("DB Task::all() error: {}", e);
|
error_!("DB Task::all() error: {}", e);
|
||||||
Context {
|
Context {
|
||||||
msg: Some(("error".to_string(), "Couldn't access the task database.".to_string())),
|
msg: Some(("error".into(), "Fail to access database.".into())),
|
||||||
tasks: vec![]
|
tasks: vec![]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -57,7 +59,7 @@ async fn new(todo_form: Form<Todo>, conn: DbConn) -> Flash<Redirect> {
|
||||||
let todo = todo_form.into_inner();
|
let todo = todo_form.into_inner();
|
||||||
if todo.description.is_empty() {
|
if todo.description.is_empty() {
|
||||||
Flash::error(Redirect::to("/"), "Description cannot be empty.")
|
Flash::error(Redirect::to("/"), "Description cannot be empty.")
|
||||||
} else if let Err(e) = Task::insert(todo, conn).await {
|
} else if let Err(e) = Task::insert(todo, &conn).await {
|
||||||
error_!("DB insertion error: {}", e);
|
error_!("DB insertion error: {}", e);
|
||||||
Flash::error(Redirect::to("/"), "Todo could not be inserted due an internal error.")
|
Flash::error(Redirect::to("/"), "Todo could not be inserted due an internal error.")
|
||||||
} else {
|
} else {
|
||||||
|
@ -66,27 +68,23 @@ async fn new(todo_form: Form<Todo>, conn: DbConn) -> Flash<Redirect> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[put("/<id>")]
|
#[put("/<id>")]
|
||||||
async fn toggle(id: i32, mut conn: DbConn) -> Result<Redirect, Template> {
|
async fn toggle(id: i32, conn: DbConn) -> Result<Redirect, Template> {
|
||||||
// TODO
|
match Task::toggle_with_id(id, &conn).await {
|
||||||
let conn2 = conn.clone().await.unwrap();
|
|
||||||
match Task::toggle_with_id(id, conn).await {
|
|
||||||
Ok(_) => Ok(Redirect::to("/")),
|
Ok(_) => Ok(Redirect::to("/")),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error_!("DB toggle({}) error: {}", id, e);
|
error_!("DB toggle({}) error: {}", id, e);
|
||||||
Err(Template::render("index", Context::err(conn2, "Failed to toggle task.".to_string()).await))
|
Err(Template::render("index", Context::err(&conn, "Failed to toggle task.").await))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[delete("/<id>")]
|
#[delete("/<id>")]
|
||||||
async fn delete(id: i32, mut conn: DbConn) -> Result<Flash<Redirect>, Template> {
|
async fn delete(id: i32, conn: DbConn) -> Result<Flash<Redirect>, Template> {
|
||||||
// TODO
|
match Task::delete_with_id(id, &conn).await {
|
||||||
let conn2 = conn.clone().await.unwrap();
|
|
||||||
match Task::delete_with_id(id, conn).await {
|
|
||||||
Ok(_) => Ok(Flash::success(Redirect::to("/"), "Todo was deleted.")),
|
Ok(_) => Ok(Flash::success(Redirect::to("/"), "Todo was deleted.")),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error_!("DB deletion({}) error: {}", id, e);
|
error_!("DB deletion({}) error: {}", id, e);
|
||||||
Err(Template::render("index", Context::err(conn2, "Failed to delete task.".to_string()).await))
|
Err(Template::render("index", Context::err(&conn, "Failed to delete task.").await))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -94,20 +92,19 @@ async fn delete(id: i32, mut conn: DbConn) -> Result<Flash<Redirect>, Template>
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
async fn index(msg: Option<FlashMessage<'_, '_>>, conn: DbConn) -> Template {
|
async fn index(msg: Option<FlashMessage<'_, '_>>, conn: DbConn) -> Template {
|
||||||
let msg = msg.map(|m| (m.name().to_string(), m.msg().to_string()));
|
let msg = msg.map(|m| (m.name().to_string(), m.msg().to_string()));
|
||||||
Template::render("index", Context::raw(conn, msg).await)
|
Template::render("index", Context::raw(&conn, msg).await)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_db_migrations(mut rocket: Rocket) -> Result<Rocket, Rocket> {
|
async fn run_db_migrations(mut rocket: Rocket) -> Result<Rocket, Rocket> {
|
||||||
let conn = DbConn::get_one(rocket.inspect().await).await.expect("database connection");
|
DbConn::get_one(rocket.inspect().await).await
|
||||||
conn.run(|c| {
|
.expect("database connection")
|
||||||
match embedded_migrations::run(c) {
|
.run(|c| match embedded_migrations::run(c) {
|
||||||
Ok(()) => Ok(rocket),
|
Ok(()) => Ok(rocket),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to run database migrations: {:?}", e);
|
error!("Failed to run database migrations: {:?}", e);
|
||||||
Err(rocket)
|
Err(rocket)
|
||||||
}
|
}
|
||||||
}
|
}).await
|
||||||
}).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[launch]
|
#[launch]
|
||||||
|
|
|
@ -29,14 +29,14 @@ pub struct Todo {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Task {
|
impl Task {
|
||||||
pub async fn all(conn: DbConn) -> QueryResult<Vec<Task>> {
|
pub async fn all(conn: &DbConn) -> QueryResult<Vec<Task>> {
|
||||||
conn.run(|c| {
|
conn.run(|c| {
|
||||||
all_tasks.order(tasks::id.desc()).load::<Task>(c)
|
all_tasks.order(tasks::id.desc()).load::<Task>(c)
|
||||||
}).await
|
}).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of affected rows: 1.
|
/// Returns the number of affected rows: 1.
|
||||||
pub async fn insert(todo: Todo, conn: DbConn) -> QueryResult<usize> {
|
pub async fn insert(todo: Todo, conn: &DbConn) -> QueryResult<usize> {
|
||||||
conn.run(|c| {
|
conn.run(|c| {
|
||||||
let t = Task { id: None, description: todo.description, completed: false };
|
let t = Task { id: None, description: todo.description, completed: false };
|
||||||
diesel::insert_into(tasks::table).values(&t).execute(c)
|
diesel::insert_into(tasks::table).values(&t).execute(c)
|
||||||
|
@ -44,7 +44,7 @@ impl Task {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of affected rows: 1.
|
/// Returns the number of affected rows: 1.
|
||||||
pub async fn toggle_with_id(id: i32, conn: DbConn) -> QueryResult<usize> {
|
pub async fn toggle_with_id(id: i32, conn: &DbConn) -> QueryResult<usize> {
|
||||||
conn.run(move |c| {
|
conn.run(move |c| {
|
||||||
let task = all_tasks.find(id).get_result::<Task>(c)?;
|
let task = all_tasks.find(id).get_result::<Task>(c)?;
|
||||||
let new_status = !task.completed;
|
let new_status = !task.completed;
|
||||||
|
@ -54,13 +54,13 @@ impl Task {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of affected rows: 1.
|
/// Returns the number of affected rows: 1.
|
||||||
pub async fn delete_with_id(id: i32, conn: DbConn) -> QueryResult<usize> {
|
pub async fn delete_with_id(id: i32, conn: &DbConn) -> QueryResult<usize> {
|
||||||
conn.run(move |c| diesel::delete(all_tasks.find(id)).execute(c)).await
|
conn.run(move |c| diesel::delete(all_tasks.find(id)).execute(c)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of affected rows.
|
/// Returns the number of affected rows.
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub async fn delete_all(conn: DbConn) -> QueryResult<usize> {
|
pub async fn delete_all(conn: &DbConn) -> QueryResult<usize> {
|
||||||
conn.run(|c| diesel::delete(all_tasks).execute(c)).await
|
conn.run(|c| diesel::delete(all_tasks).execute(c)).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,9 +19,8 @@ macro_rules! run_test {
|
||||||
let rocket = super::rocket();
|
let rocket = super::rocket();
|
||||||
let $client = Client::new(rocket).await.expect("Rocket client");
|
let $client = Client::new(rocket).await.expect("Rocket client");
|
||||||
let db = super::DbConn::get_one($client.cargo()).await;
|
let db = super::DbConn::get_one($client.cargo()).await;
|
||||||
let mut $conn = db.expect("failed to get database connection for testing");
|
let $conn = db.expect("failed to get database connection for testing");
|
||||||
let delete_conn = $conn.clone().await.expect("failed to get a second database connection for testing");
|
Task::delete_all(&$conn).await.expect("failed to delete all tasks for testing");
|
||||||
Task::delete_all(delete_conn).await.expect("failed to delete all tasks for testing");
|
|
||||||
|
|
||||||
$block
|
$block
|
||||||
})
|
})
|
||||||
|
@ -32,7 +31,7 @@ macro_rules! run_test {
|
||||||
fn test_insertion_deletion() {
|
fn test_insertion_deletion() {
|
||||||
run_test!(|client, conn| {
|
run_test!(|client, conn| {
|
||||||
// Get the tasks before making changes.
|
// Get the tasks before making changes.
|
||||||
let init_tasks = Task::all(conn.clone().await.unwrap()).await.unwrap();
|
let init_tasks = Task::all(&conn).await.unwrap();
|
||||||
|
|
||||||
// Issue a request to insert a new task.
|
// Issue a request to insert a new task.
|
||||||
client.post("/todo")
|
client.post("/todo")
|
||||||
|
@ -42,7 +41,7 @@ fn test_insertion_deletion() {
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Ensure we have one more task in the database.
|
// Ensure we have one more task in the database.
|
||||||
let new_tasks = Task::all(conn.clone().await.unwrap()).await.unwrap();
|
let new_tasks = Task::all(&conn).await.unwrap();
|
||||||
assert_eq!(new_tasks.len(), init_tasks.len() + 1);
|
assert_eq!(new_tasks.len(), init_tasks.len() + 1);
|
||||||
|
|
||||||
// Ensure the task is what we expect.
|
// Ensure the task is what we expect.
|
||||||
|
@ -54,7 +53,7 @@ fn test_insertion_deletion() {
|
||||||
client.delete(format!("/todo/{}", id)).dispatch().await;
|
client.delete(format!("/todo/{}", id)).dispatch().await;
|
||||||
|
|
||||||
// Ensure it's gone.
|
// Ensure it's gone.
|
||||||
let final_tasks = Task::all(conn).await.unwrap();
|
let final_tasks = Task::all(&conn).await.unwrap();
|
||||||
assert_eq!(final_tasks.len(), init_tasks.len());
|
assert_eq!(final_tasks.len(), init_tasks.len());
|
||||||
if final_tasks.len() > 0 {
|
if final_tasks.len() > 0 {
|
||||||
assert_ne!(final_tasks[0].description, "My first task");
|
assert_ne!(final_tasks[0].description, "My first task");
|
||||||
|
@ -72,16 +71,16 @@ fn test_toggle() {
|
||||||
.dispatch()
|
.dispatch()
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let task = Task::all(conn.clone().await.unwrap()).await.unwrap()[0].clone();
|
let task = Task::all(&conn).await.unwrap()[0].clone();
|
||||||
assert_eq!(task.completed, false);
|
assert_eq!(task.completed, false);
|
||||||
|
|
||||||
// Issue a request to toggle the task; ensure it is completed.
|
// Issue a request to toggle the task; ensure it is completed.
|
||||||
client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await;
|
client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await;
|
||||||
assert_eq!(Task::all(conn.clone().await.unwrap()).await.unwrap()[0].completed, true);
|
assert_eq!(Task::all(&conn).await.unwrap()[0].completed, true);
|
||||||
|
|
||||||
// Issue a request to toggle the task; ensure it's not completed again.
|
// Issue a request to toggle the task; ensure it's not completed again.
|
||||||
client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await;
|
client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await;
|
||||||
assert_eq!(Task::all(conn).await.unwrap()[0].completed, false);
|
assert_eq!(Task::all(&conn).await.unwrap()[0].completed, false);
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,7 +90,7 @@ fn test_many_insertions() {
|
||||||
|
|
||||||
run_test!(|client, conn| {
|
run_test!(|client, conn| {
|
||||||
// Get the number of tasks initially.
|
// Get the number of tasks initially.
|
||||||
let init_num = Task::all(conn.clone().await.unwrap()).await.unwrap().len();
|
let init_num = Task::all(&conn).await.unwrap().len();
|
||||||
let mut descs = Vec::new();
|
let mut descs = Vec::new();
|
||||||
|
|
||||||
for i in 0..ITER {
|
for i in 0..ITER {
|
||||||
|
@ -107,7 +106,7 @@ fn test_many_insertions() {
|
||||||
descs.insert(0, desc);
|
descs.insert(0, desc);
|
||||||
|
|
||||||
// Ensure the task was inserted properly and all other tasks remain.
|
// Ensure the task was inserted properly and all other tasks remain.
|
||||||
let tasks = Task::all(conn.clone().await.unwrap()).await.unwrap();
|
let tasks = Task::all(&conn).await.unwrap();
|
||||||
assert_eq!(tasks.len(), init_num + i + 1);
|
assert_eq!(tasks.len(), init_num + i + 1);
|
||||||
|
|
||||||
for j in 0..i {
|
for j in 0..i {
|
||||||
|
|
Loading…
Reference in New Issue