mirror of https://github.com/rwf2/Rocket.git
Use 'spawn_blocking' in '#[database]'.
The connection guard type generated by `#[database]` no longer implements `Deref` and `DerefMut`. Instead, it provides an `async fn run()` that gives access to the underlying connection on a closure run through `spawn_blocking()`. Additionally moves most of the implementation of `#[database]` out of generated code and into library code for better type-checking.
This commit is contained in:
parent
f7eacb6a65
commit
bc8c5b9ee2
|
@ -64,23 +64,16 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
|
||||||
let name = &invocation.db_name;
|
let name = &invocation.db_name;
|
||||||
let guard_type = &invocation.type_name;
|
let guard_type = &invocation.type_name;
|
||||||
let vis = &invocation.visibility;
|
let vis = &invocation.visibility;
|
||||||
let pool_type = Ident::new(&format!("{}Pool", guard_type), guard_type.span());
|
|
||||||
let fairing_name = format!("'{}' Database Pool", name);
|
let fairing_name = format!("'{}' Database Pool", name);
|
||||||
let span = conn_type.span().into();
|
let span = conn_type.span().into();
|
||||||
|
|
||||||
// A few useful paths.
|
// A few useful paths.
|
||||||
let databases = quote_spanned!(span => ::rocket_contrib::databases);
|
let databases = quote_spanned!(span => ::rocket_contrib::databases);
|
||||||
let Poolable = quote_spanned!(span => #databases::Poolable);
|
|
||||||
let r2d2 = quote_spanned!(span => #databases::r2d2);
|
|
||||||
let spawn_blocking = quote_spanned!(span => #databases::spawn_blocking);
|
|
||||||
let request = quote!(::rocket::request);
|
let request = quote!(::rocket::request);
|
||||||
|
|
||||||
let generated_types = quote_spanned! { span =>
|
let generated_types = quote_spanned! { span =>
|
||||||
/// The request guard type.
|
/// The request guard type.
|
||||||
#vis struct #guard_type(pub #r2d2::PooledConnection<<#conn_type as #Poolable>::Manager>);
|
#vis struct #guard_type(#databases::Connection<Self, #conn_type>);
|
||||||
|
|
||||||
/// The pool type.
|
|
||||||
#vis struct #pool_type(#r2d2::Pool<<#conn_type as #Poolable>::Manager>);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(quote! {
|
Ok(quote! {
|
||||||
|
@ -90,53 +83,31 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
|
||||||
/// Returns a fairing that initializes the associated database
|
/// Returns a fairing that initializes the associated database
|
||||||
/// connection pool.
|
/// connection pool.
|
||||||
pub fn fairing() -> impl ::rocket::fairing::Fairing {
|
pub fn fairing() -> impl ::rocket::fairing::Fairing {
|
||||||
use #databases::Poolable;
|
<#databases::ConnectionPool<Self, #conn_type>>::fairing(#fairing_name, #name)
|
||||||
|
|
||||||
::rocket::fairing::AdHoc::on_attach(#fairing_name, |mut rocket| async {
|
|
||||||
let pool = #databases::database_config(#name, rocket.config().await)
|
|
||||||
.map(<#conn_type>::pool);
|
|
||||||
|
|
||||||
match pool {
|
|
||||||
Ok(Ok(p)) => Ok(rocket.manage(#pool_type(p))),
|
|
||||||
Err(config_error) => {
|
|
||||||
::rocket::logger::error(
|
|
||||||
&format!("Database configuration failure: '{}'", #name));
|
|
||||||
::rocket::logger::error_(&format!("{}", config_error));
|
|
||||||
Err(rocket)
|
|
||||||
},
|
|
||||||
Ok(Err(pool_error)) => {
|
|
||||||
::rocket::logger::error(
|
|
||||||
&format!("Failed to initialize pool for '{}'", #name));
|
|
||||||
::rocket::logger::error_(&format!("{:?}", pool_error));
|
|
||||||
Err(rocket)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retrieves a connection of type `Self` from the `rocket`
|
/// Retrieves a connection of type `Self` from the `rocket`
|
||||||
/// instance. Returns `Some` as long as `Self::fairing()` has been
|
/// instance. Returns `Some` as long as `Self::fairing()` has been
|
||||||
/// attached and there is at least one connection in the pool.
|
/// attached.
|
||||||
pub fn get_one(cargo: &::rocket::Cargo) -> Option<Self> {
|
pub async fn get_one(cargo: &::rocket::Cargo) -> Option<Self> {
|
||||||
cargo.state::<#pool_type>()
|
<#databases::ConnectionPool<Self, #conn_type>>::get_one(cargo).await.map(Self)
|
||||||
.and_then(|pool| pool.0.get().ok())
|
|
||||||
.map(#guard_type)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl ::std::ops::Deref for #guard_type {
|
/// Runs the provided closure on a thread from a threadpool. The
|
||||||
type Target = #conn_type;
|
/// closure will be passed an `&mut r2d2::PooledConnection`.
|
||||||
|
/// `.await`ing the return value of this function yields the value
|
||||||
#[inline(always)]
|
/// returned by the closure.
|
||||||
fn deref(&self) -> &Self::Target {
|
pub async fn run<F, R>(self, f: F) -> R
|
||||||
&self.0
|
where
|
||||||
|
F: FnOnce(&mut #conn_type) -> R + Send + 'static,
|
||||||
|
R: Send + 'static,
|
||||||
|
{
|
||||||
|
self.0.run(f).await
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl ::std::ops::DerefMut for #guard_type {
|
/// Asynchronously acquires another connection from the connection pool.
|
||||||
#[inline(always)]
|
pub async fn clone(&mut self) -> ::std::result::Result<Self, ()> {
|
||||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
self.0.clone().await.map(Self)
|
||||||
&mut self.0
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,20 +116,8 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
async fn from_request(request: &'a #request::Request<'r>) -> #request::Outcome<Self, ()> {
|
async fn from_request(request: &'a #request::Request<'r>) -> #request::Outcome<Self, ()> {
|
||||||
use ::rocket::http::Status;
|
<#databases::Connection<Self, #conn_type>>::from_request(request).await.map(Self)
|
||||||
|
|
||||||
let guard = request.guard::<::rocket::State<'_, #pool_type>>();
|
|
||||||
let pool = ::rocket::try_outcome!(guard.await).0.clone();
|
|
||||||
|
|
||||||
#spawn_blocking(move || {
|
|
||||||
match pool.get() {
|
|
||||||
Ok(conn) => #request::Outcome::Success(#guard_type(conn)),
|
|
||||||
Err(_) => #request::Outcome::Failure((Status::ServiceUnavailable, ())),
|
|
||||||
}
|
|
||||||
}).await.expect("failed to spawn a blocking task to get a pooled connection")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO.async: What about spawn_blocking on drop?
|
|
||||||
}.into())
|
}.into())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,21 @@
|
||||||
error[E0277]: the trait bound `Unknown: Poolable` is not satisfied
|
error[E0277]: the trait bound `Unknown: rocket_contrib::databases::Poolable` is not satisfied
|
||||||
--> $DIR/database-types.rs:7:10
|
--> $DIR/database-types.rs:7:10
|
||||||
|
|
|
|
||||||
7 | struct A(Unknown);
|
7 | struct A(Unknown);
|
||||||
| ^^^^^^^ the trait `Poolable` is not implemented for `Unknown`
|
| ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown`
|
||||||
|
|
|
||||||
|
::: $WORKSPACE/contrib/lib/src/databases.rs:832:29
|
||||||
|
|
|
||||||
|
832 | pub struct Connection<K, C: Poolable> {
|
||||||
|
| -------- required by this bound in `rocket_contrib::databases::Connection`
|
||||||
|
|
||||||
error[E0277]: the trait bound `Vec<i32>: Poolable` is not satisfied
|
error[E0277]: the trait bound `std::vec::Vec<i32>: rocket_contrib::databases::Poolable` is not satisfied
|
||||||
--> $DIR/database-types.rs:10:10
|
--> $DIR/database-types.rs:10:10
|
||||||
|
|
|
|
||||||
10 | struct B(Vec<i32>);
|
10 | struct B(Vec<i32>);
|
||||||
| ^^^^^^^^ the trait `Poolable` is not implemented for `Vec<i32>`
|
| ^^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec<i32>`
|
||||||
|
|
|
||||||
|
::: $WORKSPACE/contrib/lib/src/databases.rs:832:29
|
||||||
|
|
|
||||||
|
832 | pub struct Connection<K, C: Poolable> {
|
||||||
|
| -------- required by this bound in `rocket_contrib::databases::Connection`
|
||||||
|
|
|
@ -1,11 +1,21 @@
|
||||||
error[E0277]: the trait bound `Unknown: rocket_contrib::databases::Poolable` is not satisfied
|
error[E0277]: the trait bound `Unknown: rocket_contrib::databases::Poolable` is not satisfied
|
||||||
--> $DIR/database-types.rs:7:10
|
--> $DIR/database-types.rs:7:10
|
||||||
|
|
|
|
||||||
7 | struct A(Unknown);
|
7 | struct A(Unknown);
|
||||||
| ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown`
|
| ^^^^^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `Unknown`
|
||||||
|
|
|
||||||
|
::: $WORKSPACE/contrib/lib/src/databases.rs
|
||||||
|
|
|
||||||
|
| pub struct Connection<K, C: Poolable> {
|
||||||
|
| -------- required by this bound in `rocket_contrib::databases::Connection`
|
||||||
|
|
||||||
error[E0277]: the trait bound `std::vec::Vec<i32>: rocket_contrib::databases::Poolable` is not satisfied
|
error[E0277]: the trait bound `std::vec::Vec<i32>: rocket_contrib::databases::Poolable` is not satisfied
|
||||||
--> $DIR/database-types.rs:10:10
|
--> $DIR/database-types.rs:10:10
|
||||||
|
|
|
|
||||||
10 | struct B(Vec<i32>);
|
10 | struct B(Vec<i32>);
|
||||||
| ^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec<i32>`
|
| ^^^ the trait `rocket_contrib::databases::Poolable` is not implemented for `std::vec::Vec<i32>`
|
||||||
|
|
|
||||||
|
::: $WORKSPACE/contrib/lib/src/databases.rs
|
||||||
|
|
|
||||||
|
| pub struct Connection<K, C: Poolable> {
|
||||||
|
| -------- required by this bound in `rocket_contrib::databases::Connection`
|
||||||
|
|
|
@ -84,9 +84,9 @@
|
||||||
//! # type Result<T> = std::result::Result<T, ()>;
|
//! # type Result<T> = std::result::Result<T, ()>;
|
||||||
//! #
|
//! #
|
||||||
//! #[get("/logs/<id>")]
|
//! #[get("/logs/<id>")]
|
||||||
//! fn get_logs(conn: LogsDbConn, id: usize) -> Result<Logs> {
|
//! async fn get_logs(conn: LogsDbConn, id: usize) -> Result<Logs> {
|
||||||
//! # /*
|
//! # /*
|
||||||
//! Logs::by_id(&*conn, id)
|
//! conn.run(|c| Logs::by_id(c, id)).await
|
||||||
//! # */
|
//! # */
|
||||||
//! # Ok(())
|
//! # Ok(())
|
||||||
//! }
|
//! }
|
||||||
|
@ -224,9 +224,7 @@
|
||||||
//! The macro generates a [`FromRequest`] implementation for the decorated type,
|
//! The macro generates a [`FromRequest`] implementation for the decorated type,
|
||||||
//! allowing the type to be used as a request guard. This implementation
|
//! allowing the type to be used as a request guard. This implementation
|
||||||
//! retrieves a connection from the database pool or fails with a
|
//! retrieves a connection from the database pool or fails with a
|
||||||
//! `Status::ServiceUnavailable` if no connections are available. The macro also
|
//! `Status::ServiceUnavailable` if connecting to the database times out.
|
||||||
//! generates an implementation of the [`Deref`](std::ops::Deref) trait with
|
|
||||||
//! the internal `Poolable` type as the target.
|
|
||||||
//!
|
//!
|
||||||
//! The macro will also generate two inherent methods on the decorated type:
|
//! The macro will also generate two inherent methods on the decorated type:
|
||||||
//!
|
//!
|
||||||
|
@ -235,11 +233,10 @@
|
||||||
//! Returns a fairing that initializes the associated database connection
|
//! Returns a fairing that initializes the associated database connection
|
||||||
//! pool.
|
//! pool.
|
||||||
//!
|
//!
|
||||||
//! * `fn get_one(&Cargo) -> Option<Self>`
|
//! * `async fn get_one(&Cargo) -> Option<Self>`
|
||||||
//!
|
//!
|
||||||
//! Retrieves a connection from the configured pool. Returns `Some` as long
|
//! Retrieves a connection wrapper from the configured pool. Returns `Some`
|
||||||
//! as `Self::fairing()` has been attached and there is at least one
|
//! as long as `Self::fairing()` has been attached.
|
||||||
//! connection in the pool.
|
|
||||||
//!
|
//!
|
||||||
//! The fairing returned from the generated `fairing()` method _must_ be
|
//! The fairing returned from the generated `fairing()` method _must_ be
|
||||||
//! attached for the request guard implementation to succeed. Putting the pieces
|
//! attached for the request guard implementation to succeed. Putting the pieces
|
||||||
|
@ -280,8 +277,8 @@
|
||||||
//!
|
//!
|
||||||
//! ## Handlers
|
//! ## Handlers
|
||||||
//!
|
//!
|
||||||
//! Finally, simply use your type as a request guard in a handler to retrieve a
|
//! Finally, use your type as a request guard in a handler to retrieve a
|
||||||
//! connection to a given database:
|
//! connection wrapper for the database:
|
||||||
//!
|
//!
|
||||||
//! ```rust
|
//! ```rust
|
||||||
//! # #[macro_use] extern crate rocket;
|
//! # #[macro_use] extern crate rocket;
|
||||||
|
@ -300,8 +297,7 @@
|
||||||
//! # }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! The generated `Deref` implementation allows easy access to the inner
|
//! A connection can be retrieved and used with the `run()` method:
|
||||||
//! connection type:
|
|
||||||
//!
|
//!
|
||||||
//! ```rust
|
//! ```rust
|
||||||
//! # #[macro_use] extern crate rocket;
|
//! # #[macro_use] extern crate rocket;
|
||||||
|
@ -320,12 +316,43 @@
|
||||||
//! }
|
//! }
|
||||||
//!
|
//!
|
||||||
//! #[get("/")]
|
//! #[get("/")]
|
||||||
//! fn my_handler(conn: MyDatabase) -> Data {
|
//! async fn my_handler(mut conn: MyDatabase) -> Data {
|
||||||
//! load_from_db(&*conn)
|
//! conn.run(|c| load_from_db(c)).await
|
||||||
//! }
|
//! }
|
||||||
//! # }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
|
//! `run()` takes the connection by value. To make multiple calls to run,
|
||||||
|
//! obtain a second connection first with `clone()`:
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! # #[macro_use] extern crate rocket;
|
||||||
|
//! # #[macro_use] extern crate rocket_contrib;
|
||||||
|
//! #
|
||||||
|
//! # #[cfg(feature = "diesel_sqlite_pool")]
|
||||||
|
//! # mod test {
|
||||||
|
//! # use rocket_contrib::databases::diesel;
|
||||||
|
//! # type Data = ();
|
||||||
|
//! #[database("my_db")]
|
||||||
|
//! struct MyDatabase(diesel::SqliteConnection);
|
||||||
|
//!
|
||||||
|
//! fn load_from_db(conn: &diesel::SqliteConnection) -> Data {
|
||||||
|
//! // Do something with connection, return some data.
|
||||||
|
//! # ()
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! #[get("/")]
|
||||||
|
//! async fn my_handler(mut conn: MyDatabase) -> Data {
|
||||||
|
//! let cloned = conn.clone().await.expect("");
|
||||||
|
//! cloned.run(|c| load_from_db(c)).await;
|
||||||
|
//!
|
||||||
|
//! // Do something else
|
||||||
|
//! conn.run(|c| load_from_db(c)).await;
|
||||||
|
//! }
|
||||||
|
//! # }
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//!
|
||||||
//! # Database Support
|
//! # Database Support
|
||||||
//!
|
//!
|
||||||
//! Built-in support is provided for many popular databases and drivers. Support
|
//! Built-in support is provided for many popular databases and drivers. Support
|
||||||
|
@ -380,18 +407,21 @@
|
||||||
|
|
||||||
pub extern crate r2d2;
|
pub extern crate r2d2;
|
||||||
|
|
||||||
#[doc(hidden)]
|
|
||||||
pub use tokio::task::spawn_blocking;
|
|
||||||
|
|
||||||
#[cfg(any(feature = "diesel_sqlite_pool",
|
#[cfg(any(feature = "diesel_sqlite_pool",
|
||||||
feature = "diesel_postgres_pool",
|
feature = "diesel_postgres_pool",
|
||||||
feature = "diesel_mysql_pool"))]
|
feature = "diesel_mysql_pool"))]
|
||||||
pub extern crate diesel;
|
pub extern crate diesel;
|
||||||
|
|
||||||
use std::fmt::{self, Display, Formatter};
|
use std::fmt::{self, Display, Formatter};
|
||||||
use std::marker::{Send, Sized};
|
use std::marker::PhantomData;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use rocket::config::{self, Value};
|
use rocket::config::{self, Value};
|
||||||
|
use rocket::fairing::{AdHoc, Fairing};
|
||||||
|
use rocket::http::Status;
|
||||||
|
use rocket::request::Outcome;
|
||||||
|
|
||||||
|
use rocket::tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||||
|
|
||||||
use self::r2d2::ManageConnection;
|
use self::r2d2::ManageConnection;
|
||||||
|
|
||||||
|
@ -688,7 +718,7 @@ pub trait Poolable: Send + Sized + 'static {
|
||||||
type Manager: ManageConnection<Connection=Self>;
|
type Manager: ManageConnection<Connection=Self>;
|
||||||
/// The associated error type in the event that constructing the connection
|
/// The associated error type in the event that constructing the connection
|
||||||
/// manager and/or the connection pool fails.
|
/// manager and/or the connection pool fails.
|
||||||
type Error;
|
type Error: std::fmt::Debug;
|
||||||
|
|
||||||
/// Creates an `r2d2` connection pool for `Manager::Connection`, returning
|
/// Creates an `r2d2` connection pool for `Manager::Connection`, returning
|
||||||
/// the pool on success.
|
/// the pool on success.
|
||||||
|
@ -780,6 +810,184 @@ impl Poolable for memcache::Client {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Unstable internal details of generated code for the #[database] attribute.
|
||||||
|
///
|
||||||
|
/// This type is implemented here instead of in generated code to ensure all
|
||||||
|
/// types are properly checked.
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub struct ConnectionPool<K, C: Poolable> {
|
||||||
|
pool: r2d2::Pool<C::Manager>,
|
||||||
|
semaphore: Arc<Semaphore>,
|
||||||
|
_marker: PhantomData<fn() -> K>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unstable internal details of generated code for the #[database] attribute.
|
||||||
|
///
|
||||||
|
/// This type is implemented here instead of in generated code to ensure all
|
||||||
|
/// types are properly checked.
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub struct Connection<K, C: Poolable> {
|
||||||
|
pool: ConnectionPool<K, C>,
|
||||||
|
connection: Option<r2d2::PooledConnection<C::Manager>>,
|
||||||
|
_permit: Option<OwnedSemaphorePermit>,
|
||||||
|
_marker: PhantomData<fn() -> K>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// A wrapper around spawn_blocking that propagates panics to the calling code
|
||||||
|
async fn run_blocking<F, R>(job: F) -> R
|
||||||
|
where
|
||||||
|
F: FnOnce() -> R + Send + 'static,
|
||||||
|
R: Send + 'static,
|
||||||
|
{
|
||||||
|
match tokio::task::spawn_blocking(job).await {
|
||||||
|
Ok(ret) => ret,
|
||||||
|
Err(e) => match e.try_into_panic() {
|
||||||
|
Ok(panic) => std::panic::resume_unwind(panic),
|
||||||
|
Err(_) => unreachable!("spawn_blocking tasks are never canceled"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
|
||||||
|
#[inline]
|
||||||
|
pub fn fairing(fairing_name: &'static str, config_name: &'static str) -> impl Fairing {
|
||||||
|
AdHoc::on_attach(fairing_name, move |mut rocket| async move {
|
||||||
|
let config = database_config(config_name, rocket.config().await);
|
||||||
|
let pool = config.map(|c| (c.pool_size, C::pool(c)));
|
||||||
|
|
||||||
|
match pool {
|
||||||
|
Ok((size, Ok(pool))) => {
|
||||||
|
let managed = ConnectionPool::<K, C> {
|
||||||
|
pool,
|
||||||
|
semaphore: Arc::new(Semaphore::new(size as usize)),
|
||||||
|
_marker: PhantomData,
|
||||||
|
};
|
||||||
|
Ok(rocket.manage(managed))
|
||||||
|
},
|
||||||
|
Err(config_error) => {
|
||||||
|
rocket::logger::error(
|
||||||
|
&format!("Database configuration failure: '{}'", config_name));
|
||||||
|
rocket::logger::error_(&config_error.to_string());
|
||||||
|
Err(rocket)
|
||||||
|
},
|
||||||
|
Ok((_, Err(pool_error))) => {
|
||||||
|
rocket::logger::error(
|
||||||
|
&format!("Failed to initialize pool for '{}'", config_name));
|
||||||
|
rocket::logger::error_(&format!("{:?}", pool_error));
|
||||||
|
Err(rocket)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(&self) -> Result<Connection<K, C>, ()> {
|
||||||
|
// TODO: timeout duration
|
||||||
|
let permit = match tokio::time::timeout(
|
||||||
|
self.pool.connection_timeout(),
|
||||||
|
self.semaphore.clone().acquire_owned()
|
||||||
|
).await {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => {
|
||||||
|
error_!("Failed to get a database connection within the timeout.");
|
||||||
|
return Err(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let pool = self.pool.clone();
|
||||||
|
|
||||||
|
// TODO: timeout duration
|
||||||
|
match run_blocking(move || pool.get_timeout(std::time::Duration::from_secs(0))).await {
|
||||||
|
Ok(c) => {
|
||||||
|
Ok(Connection {
|
||||||
|
pool: self.clone(),
|
||||||
|
connection: Some(c),
|
||||||
|
_permit: Some(permit),
|
||||||
|
_marker: PhantomData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error_!("Failed to get a database connection: {}", e);
|
||||||
|
Err(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub async fn get_one(cargo: &rocket::Cargo) -> Option<Connection<K, C>> {
|
||||||
|
match cargo.state::<Self>() {
|
||||||
|
Some(pool) => pool.get().await.ok(),
|
||||||
|
None => {
|
||||||
|
error_!("Database fairing was not attached for {}", std::any::type_name::<K>());
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<K, C: Poolable> Clone for ConnectionPool<K, C> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
pool: self.pool.clone(),
|
||||||
|
semaphore: self.semaphore.clone(),
|
||||||
|
_marker: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<K: 'static, C: Poolable> Connection<K, C> {
|
||||||
|
#[inline]
|
||||||
|
pub async fn run<F, R>(self, f: F) -> R
|
||||||
|
where F: FnOnce(&mut C) -> R + Send + 'static,
|
||||||
|
R: Send + 'static,
|
||||||
|
{
|
||||||
|
run_blocking(move || {
|
||||||
|
// 'self' contains a semaphore permit that should be held throughout
|
||||||
|
// this entire spawned task. Explicitly move it, so that a
|
||||||
|
// refactoring won't accidentally release the permit too early.
|
||||||
|
let mut this: Self = self;
|
||||||
|
|
||||||
|
let mut conn = this.connection.take()
|
||||||
|
.expect("internal invariant broken: self.connection is Some");
|
||||||
|
|
||||||
|
f(&mut conn)
|
||||||
|
}).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub async fn clone(&mut self) -> Result<Self, ()> {
|
||||||
|
self.pool.get().await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<K, C: Poolable> Drop for Connection<K, C> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(conn) = self.connection.take() {
|
||||||
|
tokio::task::spawn_blocking(|| drop(conn));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl<'a, 'r, K: 'static, C: Poolable> rocket::request::FromRequest<'a, 'r> for Connection<K, C> {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
async fn from_request(request: &'a rocket::request::Request<'r>) -> Outcome<Self, ()> {
|
||||||
|
match request.managed_state::<ConnectionPool<K, C>>() {
|
||||||
|
Some(inner) => {
|
||||||
|
match inner.get().await {
|
||||||
|
Ok(c) => Outcome::Success(c),
|
||||||
|
Err(()) => Outcome::Failure((Status::ServiceUnavailable, ())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
error_!("Database fairing was not attached for {}", std::any::type_name::<K>());
|
||||||
|
Outcome::Failure((Status::InternalServerError, ()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
|
|
|
@ -21,40 +21,31 @@ mod rusqlite_integration_test {
|
||||||
#[database("test_db")]
|
#[database("test_db")]
|
||||||
struct SqliteDb(pub rusqlite::Connection);
|
struct SqliteDb(pub rusqlite::Connection);
|
||||||
|
|
||||||
|
// Test to ensure that multiple databases of the same type can be used
|
||||||
|
#[database("test_db_2")]
|
||||||
|
struct SqliteDb2(pub rusqlite::Connection);
|
||||||
|
|
||||||
#[rocket::async_test]
|
#[rocket::async_test]
|
||||||
async fn deref_mut_impl_present() {
|
async fn test_db() {
|
||||||
let mut test_db: Map<String, Value> = Map::new();
|
let mut test_db: Map<String, Value> = Map::new();
|
||||||
let mut test_db_opts: Map<String, Value> = Map::new();
|
let mut test_db_opts: Map<String, Value> = Map::new();
|
||||||
test_db_opts.insert("url".into(), Value::String(":memory:".into()));
|
test_db_opts.insert("url".into(), Value::String(":memory:".into()));
|
||||||
test_db.insert("test_db".into(), Value::Table(test_db_opts));
|
test_db.insert("test_db".into(), Value::Table(test_db_opts.clone()));
|
||||||
|
test_db.insert("test_db_2".into(), Value::Table(test_db_opts));
|
||||||
let config = Config::build(Environment::Development)
|
let config = Config::build(Environment::Development)
|
||||||
.extra("databases", Value::Table(test_db))
|
.extra("databases", Value::Table(test_db))
|
||||||
.finalize()
|
.finalize()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let mut rocket = rocket::custom(config).attach(SqliteDb::fairing());
|
let mut rocket = rocket::custom(config).attach(SqliteDb::fairing()).attach(SqliteDb2::fairing());
|
||||||
let mut conn = SqliteDb::get_one(rocket.inspect().await).expect("unable to get connection");
|
let conn = SqliteDb::get_one(rocket.inspect().await).await.expect("unable to get connection");
|
||||||
|
|
||||||
// Rusqlite's `transaction()` method takes `&mut self`; this tests the
|
// Rusqlite's `transaction()` method takes `&mut self`; this tests that
|
||||||
// presence of a `DerefMut` trait on the generated connection type.
|
// the &mut method can be called inside the closure passed to `run()`.
|
||||||
let tx = conn.transaction().unwrap();
|
conn.run(|conn| {
|
||||||
let _: i32 = tx.query_row("SELECT 1", &[] as &[&dyn ToSql], |row| row.get(0)).expect("get row");
|
let tx = conn.transaction().unwrap();
|
||||||
tx.commit().expect("committed transaction");
|
let _: i32 = tx.query_row("SELECT 1", &[] as &[&dyn ToSql], |row| row.get(0)).expect("get row");
|
||||||
}
|
tx.commit().expect("committed transaction");
|
||||||
|
}).await;
|
||||||
#[rocket::async_test]
|
|
||||||
async fn deref_impl_present() {
|
|
||||||
let mut test_db: Map<String, Value> = Map::new();
|
|
||||||
let mut test_db_opts: Map<String, Value> = Map::new();
|
|
||||||
test_db_opts.insert("url".into(), Value::String(":memory:".into()));
|
|
||||||
test_db.insert("test_db".into(), Value::Table(test_db_opts));
|
|
||||||
let config = Config::build(Environment::Development)
|
|
||||||
.extra("databases", Value::Table(test_db))
|
|
||||||
.finalize()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let mut rocket = rocket::custom(config).attach(SqliteDb::fairing());
|
|
||||||
let conn = SqliteDb::get_one(rocket.inspect().await).expect("unable to get connection");
|
|
||||||
let _: i32 = conn.query_row("SELECT 1", &[] as &[&dyn ToSql], |row| row.get(0)).expect("get row");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,23 +25,26 @@ embed_migrations!();
|
||||||
pub struct DbConn(SqliteConnection);
|
pub struct DbConn(SqliteConnection);
|
||||||
|
|
||||||
#[derive(Debug, serde::Serialize)]
|
#[derive(Debug, serde::Serialize)]
|
||||||
struct Context<'a> {
|
struct Context {
|
||||||
msg: Option<(&'a str, &'a str)>,
|
msg: Option<(String, String)>,
|
||||||
tasks: Vec<Task>
|
tasks: Vec<Task>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Context<'a> {
|
impl Context {
|
||||||
pub fn err(conn: &DbConn, msg: &'a str) -> Context<'a> {
|
pub async fn err(conn: DbConn, msg: String) -> Context {
|
||||||
Context { msg: Some(("error", msg)), tasks: Task::all(conn).unwrap_or_default() }
|
Context {
|
||||||
|
msg: Some(("error".to_string(), msg)),
|
||||||
|
tasks: Task::all(conn).await.unwrap_or_default()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn raw(conn: &DbConn, msg: Option<(&'a str, &'a str)>) -> Context<'a> {
|
pub async fn raw(conn: DbConn, msg: Option<(String, String)>) -> Context {
|
||||||
match Task::all(conn) {
|
match Task::all(conn).await {
|
||||||
Ok(tasks) => Context { msg, tasks },
|
Ok(tasks) => Context { msg, tasks },
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error_!("DB Task::all() error: {}", e);
|
error_!("DB Task::all() error: {}", e);
|
||||||
Context {
|
Context {
|
||||||
msg: Some(("error", "Couldn't access the task database.")),
|
msg: Some(("error".to_string(), "Couldn't access the task database.".to_string())),
|
||||||
tasks: vec![]
|
tasks: vec![]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -50,11 +53,11 @@ impl<'a> Context<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/", data = "<todo_form>")]
|
#[post("/", data = "<todo_form>")]
|
||||||
fn new(todo_form: Form<Todo>, conn: DbConn) -> Flash<Redirect> {
|
async fn new(todo_form: Form<Todo>, conn: DbConn) -> Flash<Redirect> {
|
||||||
let todo = todo_form.into_inner();
|
let todo = todo_form.into_inner();
|
||||||
if todo.description.is_empty() {
|
if todo.description.is_empty() {
|
||||||
Flash::error(Redirect::to("/"), "Description cannot be empty.")
|
Flash::error(Redirect::to("/"), "Description cannot be empty.")
|
||||||
} else if let Err(e) = Task::insert(todo, &conn) {
|
} else if let Err(e) = Task::insert(todo, conn).await {
|
||||||
error_!("DB insertion error: {}", e);
|
error_!("DB insertion error: {}", e);
|
||||||
Flash::error(Redirect::to("/"), "Todo could not be inserted due an internal error.")
|
Flash::error(Redirect::to("/"), "Todo could not be inserted due an internal error.")
|
||||||
} else {
|
} else {
|
||||||
|
@ -63,42 +66,48 @@ fn new(todo_form: Form<Todo>, conn: DbConn) -> Flash<Redirect> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[put("/<id>")]
|
#[put("/<id>")]
|
||||||
fn toggle(id: i32, conn: DbConn) -> Result<Redirect, Template> {
|
async fn toggle(id: i32, mut conn: DbConn) -> Result<Redirect, Template> {
|
||||||
Task::toggle_with_id(id, &conn)
|
// TODO
|
||||||
.map(|_| Redirect::to("/"))
|
let conn2 = conn.clone().await.unwrap();
|
||||||
.map_err(|e| {
|
match Task::toggle_with_id(id, conn).await {
|
||||||
|
Ok(_) => Ok(Redirect::to("/")),
|
||||||
|
Err(e) => {
|
||||||
error_!("DB toggle({}) error: {}", id, e);
|
error_!("DB toggle({}) error: {}", id, e);
|
||||||
Template::render("index", Context::err(&conn, "Failed to toggle task."))
|
Err(Template::render("index", Context::err(conn2, "Failed to toggle task.".to_string()).await))
|
||||||
})
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[delete("/<id>")]
|
#[delete("/<id>")]
|
||||||
fn delete(id: i32, conn: DbConn) -> Result<Flash<Redirect>, Template> {
|
async fn delete(id: i32, mut conn: DbConn) -> Result<Flash<Redirect>, Template> {
|
||||||
Task::delete_with_id(id, &conn)
|
// TODO
|
||||||
.map(|_| Flash::success(Redirect::to("/"), "Todo was deleted."))
|
let conn2 = conn.clone().await.unwrap();
|
||||||
.map_err(|e| {
|
match Task::delete_with_id(id, conn).await {
|
||||||
|
Ok(_) => Ok(Flash::success(Redirect::to("/"), "Todo was deleted.")),
|
||||||
|
Err(e) => {
|
||||||
error_!("DB deletion({}) error: {}", id, e);
|
error_!("DB deletion({}) error: {}", id, e);
|
||||||
Template::render("index", Context::err(&conn, "Failed to delete task."))
|
Err(Template::render("index", Context::err(conn2, "Failed to delete task.".to_string()).await))
|
||||||
})
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn index(msg: Option<FlashMessage<'_, '_>>, conn: DbConn) -> Template {
|
async fn index(msg: Option<FlashMessage<'_, '_>>, conn: DbConn) -> Template {
|
||||||
Template::render("index", match msg {
|
let msg = msg.map(|m| (m.name().to_string(), m.msg().to_string()));
|
||||||
Some(ref msg) => Context::raw(&conn, Some((msg.name(), msg.msg()))),
|
Template::render("index", Context::raw(conn, msg).await)
|
||||||
None => Context::raw(&conn, None),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_db_migrations(mut rocket: Rocket) -> Result<Rocket, Rocket> {
|
async fn run_db_migrations(mut rocket: Rocket) -> Result<Rocket, Rocket> {
|
||||||
let conn = DbConn::get_one(rocket.inspect().await).expect("database connection");
|
let conn = DbConn::get_one(rocket.inspect().await).await.expect("database connection");
|
||||||
match embedded_migrations::run(&*conn) {
|
conn.run(|c| {
|
||||||
Ok(()) => Ok(rocket),
|
match embedded_migrations::run(c) {
|
||||||
Err(e) => {
|
Ok(()) => Ok(rocket),
|
||||||
error!("Failed to run database migrations: {:?}", e);
|
Err(e) => {
|
||||||
Err(rocket)
|
error!("Failed to run database migrations: {:?}", e);
|
||||||
|
Err(rocket)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[launch]
|
#[launch]
|
||||||
|
|
|
@ -13,6 +13,8 @@ mod schema {
|
||||||
use self::schema::tasks;
|
use self::schema::tasks;
|
||||||
use self::schema::tasks::dsl::{tasks as all_tasks, completed as task_completed};
|
use self::schema::tasks::dsl::{tasks as all_tasks, completed as task_completed};
|
||||||
|
|
||||||
|
use crate::DbConn;
|
||||||
|
|
||||||
#[table_name="tasks"]
|
#[table_name="tasks"]
|
||||||
#[derive(serde::Serialize, Queryable, Insertable, Debug, Clone)]
|
#[derive(serde::Serialize, Queryable, Insertable, Debug, Clone)]
|
||||||
pub struct Task {
|
pub struct Task {
|
||||||
|
@ -27,32 +29,38 @@ pub struct Todo {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Task {
|
impl Task {
|
||||||
pub fn all(conn: &SqliteConnection) -> QueryResult<Vec<Task>> {
|
pub async fn all(conn: DbConn) -> QueryResult<Vec<Task>> {
|
||||||
all_tasks.order(tasks::id.desc()).load::<Task>(conn)
|
conn.run(|c| {
|
||||||
|
all_tasks.order(tasks::id.desc()).load::<Task>(c)
|
||||||
|
}).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of affected rows: 1.
|
/// Returns the number of affected rows: 1.
|
||||||
pub fn insert(todo: Todo, conn: &SqliteConnection) -> QueryResult<usize> {
|
pub async fn insert(todo: Todo, conn: DbConn) -> QueryResult<usize> {
|
||||||
let t = Task { id: None, description: todo.description, completed: false };
|
conn.run(|c| {
|
||||||
diesel::insert_into(tasks::table).values(&t).execute(conn)
|
let t = Task { id: None, description: todo.description, completed: false };
|
||||||
|
diesel::insert_into(tasks::table).values(&t).execute(c)
|
||||||
|
}).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of affected rows: 1.
|
/// Returns the number of affected rows: 1.
|
||||||
pub fn toggle_with_id(id: i32, conn: &SqliteConnection) -> QueryResult<usize> {
|
pub async fn toggle_with_id(id: i32, conn: DbConn) -> QueryResult<usize> {
|
||||||
let task = all_tasks.find(id).get_result::<Task>(conn)?;
|
conn.run(move |c| {
|
||||||
let new_status = !task.completed;
|
let task = all_tasks.find(id).get_result::<Task>(c)?;
|
||||||
let updated_task = diesel::update(all_tasks.find(id));
|
let new_status = !task.completed;
|
||||||
updated_task.set(task_completed.eq(new_status)).execute(conn)
|
let updated_task = diesel::update(all_tasks.find(id));
|
||||||
|
updated_task.set(task_completed.eq(new_status)).execute(c)
|
||||||
|
}).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of affected rows: 1.
|
/// Returns the number of affected rows: 1.
|
||||||
pub fn delete_with_id(id: i32, conn: &SqliteConnection) -> QueryResult<usize> {
|
pub async fn delete_with_id(id: i32, conn: DbConn) -> QueryResult<usize> {
|
||||||
diesel::delete(all_tasks.find(id)).execute(conn)
|
conn.run(move |c| diesel::delete(all_tasks.find(id)).execute(c)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of affected rows.
|
/// Returns the number of affected rows.
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub fn delete_all(conn: &SqliteConnection) -> QueryResult<usize> {
|
pub async fn delete_all(conn: DbConn) -> QueryResult<usize> {
|
||||||
diesel::delete(all_tasks).execute(conn)
|
conn.run(|c| diesel::delete(all_tasks).execute(c)).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ use super::task::Task;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use rand::{Rng, thread_rng, distributions::Alphanumeric};
|
use rand::{Rng, thread_rng, distributions::Alphanumeric};
|
||||||
|
|
||||||
use rocket::local::blocking::Client;
|
use rocket::local::asynchronous::Client;
|
||||||
use rocket::http::{Status, ContentType};
|
use rocket::http::{Status, ContentType};
|
||||||
|
|
||||||
// We use a lock to synchronize between tests so DB operations don't collide.
|
// We use a lock to synchronize between tests so DB operations don't collide.
|
||||||
|
@ -15,13 +15,16 @@ macro_rules! run_test {
|
||||||
(|$client:ident, $conn:ident| $block:expr) => ({
|
(|$client:ident, $conn:ident| $block:expr) => ({
|
||||||
let _lock = DB_LOCK.lock();
|
let _lock = DB_LOCK.lock();
|
||||||
|
|
||||||
let rocket = super::rocket();
|
rocket::async_test(async move {
|
||||||
let $client = Client::new(rocket).expect("Rocket client");
|
let rocket = super::rocket();
|
||||||
let db = super::DbConn::get_one($client.cargo());
|
let $client = Client::new(rocket).await.expect("Rocket client");
|
||||||
let $conn = db.expect("failed to get database connection for testing");
|
let db = super::DbConn::get_one($client.cargo()).await;
|
||||||
Task::delete_all(&$conn).expect("failed to delete all tasks for testing");
|
let mut $conn = db.expect("failed to get database connection for testing");
|
||||||
|
let delete_conn = $conn.clone().await.expect("failed to get a second database connection for testing");
|
||||||
|
Task::delete_all(delete_conn).await.expect("failed to delete all tasks for testing");
|
||||||
|
|
||||||
$block
|
$block
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,16 +32,17 @@ macro_rules! run_test {
|
||||||
fn test_insertion_deletion() {
|
fn test_insertion_deletion() {
|
||||||
run_test!(|client, conn| {
|
run_test!(|client, conn| {
|
||||||
// Get the tasks before making changes.
|
// Get the tasks before making changes.
|
||||||
let init_tasks = Task::all(&conn).unwrap();
|
let init_tasks = Task::all(conn.clone().await.unwrap()).await.unwrap();
|
||||||
|
|
||||||
// Issue a request to insert a new task.
|
// Issue a request to insert a new task.
|
||||||
client.post("/todo")
|
client.post("/todo")
|
||||||
.header(ContentType::Form)
|
.header(ContentType::Form)
|
||||||
.body("description=My+first+task")
|
.body("description=My+first+task")
|
||||||
.dispatch();
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
// Ensure we have one more task in the database.
|
// Ensure we have one more task in the database.
|
||||||
let new_tasks = Task::all(&conn).unwrap();
|
let new_tasks = Task::all(conn.clone().await.unwrap()).await.unwrap();
|
||||||
assert_eq!(new_tasks.len(), init_tasks.len() + 1);
|
assert_eq!(new_tasks.len(), init_tasks.len() + 1);
|
||||||
|
|
||||||
// Ensure the task is what we expect.
|
// Ensure the task is what we expect.
|
||||||
|
@ -47,10 +51,10 @@ fn test_insertion_deletion() {
|
||||||
|
|
||||||
// Issue a request to delete the task.
|
// Issue a request to delete the task.
|
||||||
let id = new_tasks[0].id.unwrap();
|
let id = new_tasks[0].id.unwrap();
|
||||||
client.delete(format!("/todo/{}", id)).dispatch();
|
client.delete(format!("/todo/{}", id)).dispatch().await;
|
||||||
|
|
||||||
// Ensure it's gone.
|
// Ensure it's gone.
|
||||||
let final_tasks = Task::all(&conn).unwrap();
|
let final_tasks = Task::all(conn).await.unwrap();
|
||||||
assert_eq!(final_tasks.len(), init_tasks.len());
|
assert_eq!(final_tasks.len(), init_tasks.len());
|
||||||
if final_tasks.len() > 0 {
|
if final_tasks.len() > 0 {
|
||||||
assert_ne!(final_tasks[0].description, "My first task");
|
assert_ne!(final_tasks[0].description, "My first task");
|
||||||
|
@ -65,18 +69,19 @@ fn test_toggle() {
|
||||||
client.post("/todo")
|
client.post("/todo")
|
||||||
.header(ContentType::Form)
|
.header(ContentType::Form)
|
||||||
.body("description=test_for_completion")
|
.body("description=test_for_completion")
|
||||||
.dispatch();
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
let task = Task::all(&conn).unwrap()[0].clone();
|
let task = Task::all(conn.clone().await.unwrap()).await.unwrap()[0].clone();
|
||||||
assert_eq!(task.completed, false);
|
assert_eq!(task.completed, false);
|
||||||
|
|
||||||
// Issue a request to toggle the task; ensure it is completed.
|
// Issue a request to toggle the task; ensure it is completed.
|
||||||
client.put(format!("/todo/{}", task.id.unwrap())).dispatch();
|
client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await;
|
||||||
assert_eq!(Task::all(&conn).unwrap()[0].completed, true);
|
assert_eq!(Task::all(conn.clone().await.unwrap()).await.unwrap()[0].completed, true);
|
||||||
|
|
||||||
// Issue a request to toggle the task; ensure it's not completed again.
|
// Issue a request to toggle the task; ensure it's not completed again.
|
||||||
client.put(format!("/todo/{}", task.id.unwrap())).dispatch();
|
client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await;
|
||||||
assert_eq!(Task::all(&conn).unwrap()[0].completed, false);
|
assert_eq!(Task::all(conn).await.unwrap()[0].completed, false);
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +91,7 @@ fn test_many_insertions() {
|
||||||
|
|
||||||
run_test!(|client, conn| {
|
run_test!(|client, conn| {
|
||||||
// Get the number of tasks initially.
|
// Get the number of tasks initially.
|
||||||
let init_num = Task::all(&conn).unwrap().len();
|
let init_num = Task::all(conn.clone().await.unwrap()).await.unwrap().len();
|
||||||
let mut descs = Vec::new();
|
let mut descs = Vec::new();
|
||||||
|
|
||||||
for i in 0..ITER {
|
for i in 0..ITER {
|
||||||
|
@ -95,13 +100,14 @@ fn test_many_insertions() {
|
||||||
client.post("/todo")
|
client.post("/todo")
|
||||||
.header(ContentType::Form)
|
.header(ContentType::Form)
|
||||||
.body(format!("description={}", desc))
|
.body(format!("description={}", desc))
|
||||||
.dispatch();
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
// Record the description we choose for this iteration.
|
// Record the description we choose for this iteration.
|
||||||
descs.insert(0, desc);
|
descs.insert(0, desc);
|
||||||
|
|
||||||
// Ensure the task was inserted properly and all other tasks remain.
|
// Ensure the task was inserted properly and all other tasks remain.
|
||||||
let tasks = Task::all(&conn).unwrap();
|
let tasks = Task::all(conn.clone().await.unwrap()).await.unwrap();
|
||||||
assert_eq!(tasks.len(), init_num + i + 1);
|
assert_eq!(tasks.len(), init_num + i + 1);
|
||||||
|
|
||||||
for j in 0..i {
|
for j in 0..i {
|
||||||
|
@ -117,7 +123,8 @@ fn test_bad_form_submissions() {
|
||||||
// Submit an empty form. We should get a 422 but no flash error.
|
// Submit an empty form. We should get a 422 but no flash error.
|
||||||
let res = client.post("/todo")
|
let res = client.post("/todo")
|
||||||
.header(ContentType::Form)
|
.header(ContentType::Form)
|
||||||
.dispatch();
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
let mut cookies = res.headers().get("Set-Cookie");
|
let mut cookies = res.headers().get("Set-Cookie");
|
||||||
assert_eq!(res.status(), Status::UnprocessableEntity);
|
assert_eq!(res.status(), Status::UnprocessableEntity);
|
||||||
|
@ -128,7 +135,8 @@ fn test_bad_form_submissions() {
|
||||||
let res = client.post("/todo")
|
let res = client.post("/todo")
|
||||||
.header(ContentType::Form)
|
.header(ContentType::Form)
|
||||||
.body("description=")
|
.body("description=")
|
||||||
.dispatch();
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
let mut cookies = res.headers().get("Set-Cookie");
|
let mut cookies = res.headers().get("Set-Cookie");
|
||||||
assert!(cookies.any(|value| value.contains("error")));
|
assert!(cookies.any(|value| value.contains("error")));
|
||||||
|
@ -137,7 +145,8 @@ fn test_bad_form_submissions() {
|
||||||
let res = client.post("/todo")
|
let res = client.post("/todo")
|
||||||
.header(ContentType::Form)
|
.header(ContentType::Form)
|
||||||
.body("evil=smile")
|
.body("evil=smile")
|
||||||
.dispatch();
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
let mut cookies = res.headers().get("Set-Cookie");
|
let mut cookies = res.headers().get("Set-Cookie");
|
||||||
assert_eq!(res.status(), Status::UnprocessableEntity);
|
assert_eq!(res.status(), Status::UnprocessableEntity);
|
||||||
|
|
|
@ -205,7 +205,6 @@ request-local state to implement request timing.
|
||||||
[`FromRequest` request-local state]: @api/rocket/request/trait.FromRequest.html#request-local-state
|
[`FromRequest` request-local state]: @api/rocket/request/trait.FromRequest.html#request-local-state
|
||||||
[`Fairing`]: @api/rocket/fairing/trait.Fairing.html#request-local-state
|
[`Fairing`]: @api/rocket/fairing/trait.Fairing.html#request-local-state
|
||||||
|
|
||||||
<!-- TODO.async: rewrite? -->
|
|
||||||
## Databases
|
## Databases
|
||||||
|
|
||||||
Rocket includes built-in, ORM-agnostic support for databases. In particular,
|
Rocket includes built-in, ORM-agnostic support for databases. In particular,
|
||||||
|
@ -222,7 +221,7 @@ three simple steps:
|
||||||
|
|
||||||
1. Configure the databases in `Rocket.toml`.
|
1. Configure the databases in `Rocket.toml`.
|
||||||
2. Associate a request guard type and fairing with each database.
|
2. Associate a request guard type and fairing with each database.
|
||||||
3. Use the request guard to retrieve a connection in a handler.
|
3. Use the request guard to retrieve and use a connection in a handler.
|
||||||
|
|
||||||
Presently, Rocket provides built-in support for the following databases:
|
Presently, Rocket provides built-in support for the following databases:
|
||||||
|
|
||||||
|
@ -301,7 +300,7 @@ fn rocket() -> rocket::Rocket {
|
||||||
```
|
```
|
||||||
|
|
||||||
That's it! Whenever a connection to the database is needed, use your type as a
|
That's it! Whenever a connection to the database is needed, use your type as a
|
||||||
request guard:
|
request guard. The database can be accessed by calling the `run` method:
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
# #[macro_use] extern crate rocket;
|
# #[macro_use] extern crate rocket;
|
||||||
|
@ -315,9 +314,9 @@ request guard:
|
||||||
# type Logs = ();
|
# type Logs = ();
|
||||||
|
|
||||||
#[get("/logs/<id>")]
|
#[get("/logs/<id>")]
|
||||||
fn get_logs(conn: LogsDbConn, id: usize) -> Logs {
|
async fn get_logs(conn: LogsDbConn, id: usize) -> Logs {
|
||||||
# /*
|
# /*
|
||||||
logs::filter(id.eq(log_id)).load(&*conn)
|
conn.run(|c| logs::filter(id.eq(log_id)).load(c)).await
|
||||||
# */
|
# */
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
@ -329,6 +328,15 @@ fn get_logs(conn: LogsDbConn, id: usize) -> Logs {
|
||||||
syntax. Rocket does not provide an ORM. It is up to you to decide how to model
|
syntax. Rocket does not provide an ORM. It is up to you to decide how to model
|
||||||
your application's data.
|
your application's data.
|
||||||
|
|
||||||
|
! note
|
||||||
|
|
||||||
|
The database engines supported by `#[database]` are *synchronous*. Normally,
|
||||||
|
using such a database would block the thread of execution. To prevent this,
|
||||||
|
the `run()` function automatically uses a thread pool so that database access
|
||||||
|
does not interfere with other in-flight requests. See [Cooperative
|
||||||
|
Multitasking](../overview/#cooperative-multitasking) for more information on
|
||||||
|
why this is necessary.
|
||||||
|
|
||||||
If your application uses features of a database engine that are not available
|
If your application uses features of a database engine that are not available
|
||||||
by default, for example support for `chrono` or `uuid`, you may enable those
|
by default, for example support for `chrono` or `uuid`, you may enable those
|
||||||
features by adding them in `Cargo.toml` like so:
|
features by adding them in `Cargo.toml` like so:
|
||||||
|
|
Loading…
Reference in New Issue