diff --git a/Cargo.toml b/Cargo.toml index 74475d03..90e878fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "core/codegen_next/", "core/http/", "contrib/lib", + "contrib/codegen", "examples/cookies", "examples/errors", "examples/form_validation", diff --git a/contrib/codegen/Cargo.toml b/contrib/codegen/Cargo.toml new file mode 100644 index 00000000..83579635 --- /dev/null +++ b/contrib/codegen/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "rocket_contrib_codegen" +version = "0.4.0-dev" +authors = ["Sergio Benitez "] +description = "Procedural macros for the Rocket contrib libraries." +documentation = "https://api.rocket.rs/rocket_contrib_codegen/" +homepage = "https://rocket.rs" +repository = "https://github.com/SergioBenitez/Rocket" +readme = "../../../README.md" +keywords = ["rocket", "contrib", "code", "generation", "proc-macro"] +license = "MIT/Apache-2.0" + +# if publishing, add to config scripts +publish = false + +[features] +database_attribute = [] + +[lib] +proc-macro = true + +[dependencies] +quote = "0.6" +proc-macro2 = { version = "0.4", features = ["nightly"] } +syn = { version = "0.14", features = ["full", "extra-traits"] } diff --git a/contrib/codegen/src/database.rs b/contrib/codegen/src/database.rs new file mode 100644 index 00000000..c6eff17c --- /dev/null +++ b/contrib/codegen/src/database.rs @@ -0,0 +1,133 @@ +use proc_macro::{TokenStream, Diagnostic}; +use syn::{DataStruct, Fields, Data, Type, LitStr, DeriveInput, Ident, Visibility}; +use spanned::Spanned; + +type Result = ::std::result::Result; + +#[derive(Debug)] +struct DatabaseInvocation { + /// The name of the structure on which `#[database(..)] struct This(..)` was invoked. + type_name: Ident, + /// The visibility of the structure on which `#[database(..)] struct This(..)` was invoked. + visibility: Visibility, + /// The database name as passed in via #[database('database name')]. + db_name: String, + /// The entire structure that the `database` attribute was called on. + structure: DataStruct, + /// The type inside the structure: struct MyDb(ThisType). + connection_type: Type, +} + +const EXAMPLE: &str = "example: `struct MyDatabase(diesel::SqliteConnection);`"; +const ONLY_ON_STRUCTS_MSG: &str = "`database` attribute can only be used on structs"; +const ONLY_UNNAMED_FIELDS: &str = "`database` attribute can only be applied to structs with \ + exactly one unnamed field"; +const NO_GENERIC_STRUCTS: &str = "`database` attribute cannot be applied to a struct with a \ + generic type"; + +fn parse_invocation(attr: TokenStream, input: TokenStream) -> Result { + let attr_stream2 = ::proc_macro2::TokenStream::from(attr); + let attr_span = attr_stream2.span(); + let string_lit = ::syn::parse2::(attr_stream2) + .map_err(|_| attr_span.error("expected string literal"))?; + + let input = ::syn::parse::(input).unwrap(); + if !input.generics.params.is_empty() { + return Err(input.span().error(NO_GENERIC_STRUCTS)); + } + + let structure = match input.data { + Data::Struct(s) => s, + _ => return Err(input.span().error(ONLY_ON_STRUCTS_MSG)) + }; + + let inner_type = match structure.fields { + Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => { + let first = fields.unnamed.first().expect("checked length"); + first.value().ty.clone() + } + _ => return Err(structure.fields.span().error(ONLY_UNNAMED_FIELDS).help(EXAMPLE)) + }; + + Ok(DatabaseInvocation { + type_name: input.ident, + visibility: input.vis, + db_name: string_lit.value(), + structure: structure, + connection_type: inner_type, + }) +} + +pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result { + let invocation = parse_invocation(attr, input)?; + + let connection_type = &invocation.connection_type; + let database_name = &invocation.db_name; + let request_guard_type = &invocation.type_name; + let request_guard_vis = &invocation.visibility; + let pool_type = Ident::new(&format!("{}Pool", request_guard_type), request_guard_type.span()); + + let tokens = quote! { + #request_guard_vis struct #request_guard_type( + pub ::rocket_contrib::databases::r2d2::PooledConnection<<#connection_type as ::rocket_contrib::databases::Poolable>::Manager> + ); + #request_guard_vis struct #pool_type( + ::rocket_contrib::databases::r2d2::Pool<<#connection_type as ::rocket_contrib::databases::Poolable>::Manager> + ); + + impl #request_guard_type { + pub fn fairing() -> impl ::rocket::fairing::Fairing { + use ::rocket_contrib::databases::Poolable; + + ::rocket::fairing::AdHoc::on_attach(|rocket| { + let pool = ::rocket_contrib::databases::database_config(#database_name, rocket.config()) + .map(#connection_type::pool); + + match pool { + Ok(Ok(p)) => Ok(rocket.manage(#pool_type(p))), + Err(config_error) => { + ::rocket::logger::log_err(&format!("Error while instantiating database: '{}': {}", #database_name, config_error)); + Err(rocket) + }, + Ok(Err(pool_error)) => { + ::rocket::logger::log_err(&format!("Error initializing pool for '{}': {:?}", #database_name, pool_error)); + Err(rocket) + }, + } + }) + } + + /// Retrieves a connection of type `Self` from the `rocket` instance. Returns `Some` as long as + /// `Self::fairing()` has been attached and there is at least one connection in the pool. + pub fn get_one(rocket: &::rocket::Rocket) -> Option { + rocket.state::<#pool_type>() + .and_then(|pool| pool.0.get().ok()) + .map(#request_guard_type) + } + } + + impl ::std::ops::Deref for #request_guard_type { + type Target = #connection_type; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl<'a, 'r> ::rocket::request::FromRequest<'a, 'r> for #request_guard_type { + type Error = (); + + fn from_request(request: &'a ::rocket::request::Request<'r>) -> ::rocket::request::Outcome { + let pool = request.guard::<::rocket::State<#pool_type>>()?; + + match pool.0.get() { + Ok(conn) => ::rocket::Outcome::Success(#request_guard_type(conn)), + Err(_) => ::rocket::Outcome::Failure((::rocket::http::Status::ServiceUnavailable, ())), + } + } + } + }; + + Ok(tokens.into()) +} diff --git a/contrib/codegen/src/lib.rs b/contrib/codegen/src/lib.rs new file mode 100644 index 00000000..5c84ba51 --- /dev/null +++ b/contrib/codegen/src/lib.rs @@ -0,0 +1,47 @@ +#![feature(proc_macro_span, proc_macro_diagnostic)] +#![recursion_limit="256"] + +//! # Rocket Contrib - Code Generation +//! This crate implements the code generation portion of the Rocket Contrib +//! crate. This is for officially sanctioned contributor libraries that require +//! code generation of some kind. +//! +//! This crate includes custom derives and procedural macros and will expand +//! as-needed if future `rocket_contrib` features require code generation +//! facilities. +//! +//! ## Procedural Macros +//! +//! This crate implements the following procedural macros: +//! +//! * **databases** +//! +//! The syntax for the `databases` macro is: +//! +//!
+//! macro := database(DATABASE_NAME)
+//! DATABASE_NAME := (string literal)
+//! 
+ +extern crate syn; +extern crate proc_macro; +extern crate proc_macro2; +#[macro_use] extern crate quote; + +mod spanned; + +#[cfg(feature = "database_attribute")] +mod database; + +#[allow(dead_code)] +use proc_macro::TokenStream; + +#[cfg(feature = "database_attribute")] +#[proc_macro_attribute] +/// The procedural macro for the `databases` annotation. +pub fn database(attr: TokenStream, input: TokenStream) -> TokenStream { + ::database::database_attr(attr, input).unwrap_or_else(|diag| { + diag.emit(); + TokenStream::new() + }) +} diff --git a/contrib/codegen/src/spanned.rs b/contrib/codegen/src/spanned.rs new file mode 100644 index 00000000..1b8e7a1b --- /dev/null +++ b/contrib/codegen/src/spanned.rs @@ -0,0 +1,29 @@ +use proc_macro::Span; + +use quote::ToTokens; + +pub trait Spanned { + fn span(&self) -> Span; +} + +// FIXME: Remove this once proc_macro's stabilize. +impl Spanned for T { + fn span(&self) -> Span { + let token_stream = self.into_token_stream(); + let mut iter = token_stream.into_iter(); + let mut span = match iter.next() { + Some(tt) => tt.span().unstable(), + None => { + return Span::call_site(); + } + }; + + for tt in iter { + if let Some(joined) = span.join(tt.span().unstable()) { + span = joined; + } + } + + span + } +} diff --git a/contrib/lib/Cargo.toml b/contrib/lib/Cargo.toml index 70bda010..1ada1803 100644 --- a/contrib/lib/Cargo.toml +++ b/contrib/lib/Cargo.toml @@ -17,6 +17,16 @@ msgpack = ["serde", "rmp-serde"] tera_templates = ["tera", "templates"] handlebars_templates = ["handlebars", "templates"] static_files = [] +database_pool_codegen = ["rocket_contrib_codegen", "rocket_contrib_codegen/database_attribute"] +database_pool = ["r2d2", "database_pool_codegen"] +diesel_pg_pool = ["database_pool", "diesel/postgres", "diesel/r2d2"] +diesel_sqlite_pool = ["database_pool", "diesel/sqlite", "diesel/r2d2"] +diesel_mysql_pool = ["database_pool", "diesel/mysql", "diesel/r2d2"] +postgres_pool = ["database_pool", "postgres", "r2d2_postgres"] +mysql_pool = ["database_pool", "mysql", "r2d2_mysql"] +sqlite_pool = ["database_pool", "rusqlite", "r2d2_sqlite"] +cypher_pool = ["database_pool", "rusted_cypher", "r2d2_cypher"] +redis_pool = ["database_pool", "redis", "r2d2_redis"] # Internal use only. templates = ["serde", "serde_json", "glob"] @@ -38,6 +48,23 @@ handlebars = { version = "1.0", optional = true } glob = { version = "0.2", optional = true } tera = { version = "0.11", optional = true } +# Database dependencies +diesel = { version = "1.0", default-features = false, optional = true } +postgres = { version = "0.15", optional = true } +r2d2 = { version = "0.8", optional = true } +r2d2_postgres = { version = "0.14", optional = true } +mysql = { version = "14", optional = true } +r2d2_mysql = { version = "9", optional = true } +rusqlite = { version = "0.13.0", optional = true } +r2d2_sqlite = { version = "0.5", optional = true } +rusted_cypher = { version = "1", optional = true } +r2d2_cypher = { version = "0.4", optional = true } +redis = { version = "0.8", optional = true } +r2d2_redis = { version = "0.7", optional = true } + +# Contrib codegen dependencies +rocket_contrib_codegen = { path = "../codegen", optional = true } + [dev-dependencies] rocket_codegen = { version = "0.4.0-dev", path = "../../core/codegen" } diff --git a/contrib/lib/src/databases.rs b/contrib/lib/src/databases.rs new file mode 100644 index 00000000..55f24774 --- /dev/null +++ b/contrib/lib/src/databases.rs @@ -0,0 +1,917 @@ +//! # Overview +//! +//! This module provides traits, utilities, and a procedural macro that allows +//! you to easily connect your Rocket application to databases through +//! connection pools. A _database connection pool_ is a data structure that +//! maintains active database connections for later use in the application. +//! This implementation of connection pooling support is based on +//! [`r2d2`](https://crates.io/crates/r2d2) and exposes connections through +//! [request guards](../../rocket/request/trait.FromRequest.html). Databases are +//! individually configured through Rocket's regular configuration mechanisms: a +//! `Rocket.toml` file, environment variables, or procedurally. +//! +//! Connecting your Rocket application to a database using this library occurs +//! in three simple steps: +//! +//! 1. Configure your databases in `Rocket.toml`. +//! (see [Configuration](#configuration)) +//! 2. Associate a request guard type and fairing with each database. +//! (see [Guard Types](#guard-types)) +//! 3. Use the request guard to retrieve a connection in a handler. +//! (see [Handlers](#handlers)) +//! +//! For a list of supported databases, see [Provided Databases](#provided). +//! This support can be easily extended by implementing the +//! [`Poolable`](trait.Poolable.html) trait. See [Extending](#extending) +//! for more. +//! +//! The next section provides a complete but un-detailed example of these steps +//! in actions. The sections following provide more detail for each component. +//! +//! ## Example +//! +//! Before using this library, the `database_pool` feature in `rocket_contrib` +//! must be enabled: +//! +//! ```toml +//! [dependencies.rocket_contrib] +//! version = "0.4.0-dev" +//! default-features = false +//! features = ["database_pool", "diesel_sqlite_pool"] +//! ``` +//! +//! In `Rocket.toml` or the equivalent via environment variables: +//! +//! ```toml +//! [global.databases] +//! sqlite_logs = { url = "/path/to/database.sqlite" } +//! ``` +//! +//! In your application's source code, one-time: +//! +//! ```rust,ignore +//! #![feature(use_extern_macros)] +//! extern crate rocket; +//! extern crate rocket_contrib; +//! +//! use rocket_contrib::databases::{database, diesel}; +//! +//! #[database("sqlite_logs")] +//! struct LogsDbConn(diesel::SqliteConnection); +//! +//! fn main() { +//! rocket::ignite() +//! .attach(LogsDbConn::fairing()) +//! .launch(); +//! } +//! ``` +//! +//! Whenever a connection to the database is needed: +//! +//! ```rust,ignore +//! #[get("/logs/")] +//! fn get_logs(conn: LogsDbConn, id: LogId) -> Result { +//! Logs::by_id(&conn, id) +//! } +//! ``` +//! +//! # Usage +//! +//! ## Configuration +//! +//! There are a few ways to configure your database connection. You can use the +//! `Rocket.toml` file, you can build it yourself procedurally via the +//! `rocket::custom()` method, or through environment variables. +//! +//! ### Configuring via `Rocket.toml` +//! +//! The following examples are all valid ways of configuring your database via +//! the `Rocket.toml` file. +//! +//! The basic structure includes attaching a key to the `global.databases` table +//! and including the __required__ keys `url` and `pool_size`. Additional +//! options that can be added to the table vary by adapter and are referenced +//! below in the [Supported Databases](#provided) section. +//! +//! ```toml +//! [global.databases] +//! my_database = { url = "database.sqlite", pool_size = 10 } +//! +//! [[global.databases.other_database]] +//! url = "mysql://root:root@localhost/other_database +//! pool_size = 25 +//! ``` +//! +//! ### Configuring procedurally +//! +//! It's also possible to procedurally configure your database via the +//! `rocket::custom()` method. Below is an example of doing this: +//! +//! ```rust,ignore +//! extern crate rocket; +//! +//! use std::io::Error; +//! use std::collections::HashMap; +//! use rocket::config::{Config, Environment, Value}; +//! +//! fn main() { +//! let mut database_config = HashMap::new(); +//! let mut databases = HashMap::new(); +//! +//! database_config.insert("url", Value::from("database.sqlite")); +//! databases.insert("my_db", Value::from(database_config)); +//! +//! let config = Config::build(Environment::Development) +//! .extra("databases", databases) +//! .finalize() +//! .unwrap(); +//! +//! rocket::custom(config).launch(); +//! } +//! ``` +//! +//! ### Configuring via Environment Variable +//! +//! The final way to configure your databases is via an environment variable. +//! Following the syntax laid out in the guide on [Environment Variables](https://rocket.rs/guide/configuration/#environment-variables), +//! you can configure your database this way. Below is an example +//! +//! ```bash +//! ROCKET_DATABASES={my_db={url="db.sqlite"}} +//! ``` +//! +//! ## Guard Types +//! +//! The included database support generates request guard types that can be used +//! with Rocket handlers. In order to associate a configured database with a +//! type, you need to use the `database` procedural macro: +//! +//! ```rust +//! # #![feature(use_extern_macros)] +//! # extern crate rocket; +//! # extern crate rocket_contrib; +//! # use rocket_contrib::databases::{database, diesel}; +//! +//! #[database("my_db")] +//! struct MyDatabase(diesel::SqliteConnection); +//! ``` +//! +//! From there, the macro will generate code to turn your defined type into a +//! valid request guard type. The interior type must have an implementation of +//! the [`Poolable` trait](trait.Poolable.html). The trait implements methods +//! on the interior type that are used by the generated code to spin up a +//! connection pool. The trait can be used to extend other connection types that +//! aren't supported in this library. See the section on [Extending](#extending) +//! for more information. +//! +//! The generated code will give your defined type two methods, `get_one` and +//! `fairing`, as well as implementations of the [`FromRequest`](../../rocket/request/trait.FromRequest.html) +//! and [`Deref`](../../std/ops/trait.Deref.html) traits. +//! +//! The `fairing` method will allow you to attach your database type to the +//! application state via the method call. You __will need__ to call the +//! `fairing` method on your type in order to be able to retrieve connections +//! in your request guards. +//! +//! Below is an example: +//! +//! ```rust,ignore +//! # #![feature(use_extern_macros)] +//! # +//! # extern crate rocket; +//! # extern crate rocket_contrib; +//! # +//! # use std::collections::HashMap; +//! # use rocket::config::{Config, Environment, Value}; +//! # use rocket_contrib::databases::{database, diesel}; +//! # +//! #[database("my_db")] +//! struct MyDatabase(diesel::SqliteConnection); +//! +//! fn main() { +//! # let mut db_config = HashMap::new(); +//! # let mut databases = HashMap::new(); +//! # +//! # db_config.insert("url", Value::from("database.sqlite")); +//! # db_config.insert("pool_size", Value::from(10)); +//! # databases.insert("my_db", Value::from(db_config)); +//! # +//! # let config = Config::build(Environment::Development) +//! # .extra("databases", databases) +//! # .finalize() +//! # .unwrap(); +//! # +//! rocket::custom(config) +//! .attach(MyDatabase::fairing()); // Required! +//! .launch(); +//! } +//! ``` +//! +//! ## Handlers +//! +//! For request handlers, you should use the database type you defined in your +//! code as a request guard. Because of the `FromRequest` implementation that's +//! generated at compile-time, you can use this type in such a way. For example: +//! +//! ```rust,ignore +//! #[database("my_db") +//! struct MyDatabase(diesel::MysqlConnection); +//! ... +//! #[get("/")] +//! fn my_handler(conn: MyDatabase) { +//! ... +//! } +//! ``` +//! +//! Additionally, because of the `Deref` implementation, you can dereference +//! the database type in order to access the inner connection type. For example: +//! +//! ```rust,ignore +//! #[get("/")] +//! fn my_handler(conn: MyDatabase) { +//! ... +//! Thing::load(&conn); +//! ... +//! } +//! ``` +//! +//! Under the hood, the dereferencing of your type is returning the interior +//! type of your connection: +//! +//! ```rust,ignore +//! &self.0 +//! ``` +//! +//! This section should be simple. It should cover: +//! +//! * The fact that `MyType` is not a request guard, and you can use it. +//! * The `Deref` impl and what it means for using `&my_conn`. +//! +//! # Database Support +//! +//! This library provides built-in support for many popular databases and their +//! corresponding drivers. It also makes extending this support simple. +//! +//! ## Provided +//! +//! The list below includes all presently supported database adapters, their +//! corresponding [`Poolable`] type, and any special considerations for +//! configuration, if any. +//! +//! | Database Kind | Driver | `Poolable` Type | Feature | Notes | +//! | -- ------------- | ----------------------- | ------------------------- | --------------------- | ----- | +//! | MySQL | [Diesel](https://diesel.rs) | [`diesel::MysqlConnection`](http://docs.diesel.rs/diesel/mysql/struct.MysqlConnection.html) | `diesel_mysql_pool` | None | +//! | MySQL | [`rust-mysql-simple`](https://github.com/blackbeam/rust-mysql-simple) | [`mysql::conn`](https://docs.rs/mysql/14.0.0/mysql/struct.Conn.html) | `mysql_pool` | None | +//! | Postgres | [Diesel](https://diesel.rs) | [`diesel::PgConnection`](http://docs.diesel.rs/diesel/pg/struct.PgConnection.html) | `diesel_postgres_pool` | None | +//! | Postgres | [Rust-Postgres](https://github.com/sfackler/rust-postgres) | [`postgres::Connection`](https://docs.rs/postgres/0.15.2/postgres/struct.Connection.html) | `postgres_pool` | None | +//! | Sqlite | [Diesel](https://diesel.rs) | [`diesel::SqliteConnection`](http://docs.diesel.rs/diesel/prelude/struct.SqliteConnection.html) | `diesel_sqlite_pool` | None | +//! | Sqlite | [`Rustqlite`](https://github.com/jgallagher/rusqlite) | [`rusqlite::Connection`](https://docs.rs/rusqlite/0.13.0/rusqlite/struct.Connection.html) | `sqlite_pool` | None | +//! | Neo4j | [`rusted_cypher`](https://github.com/livioribeiro/rusted-cypher) | [`rusted_cypher::GraphClient`](https://docs.rs/rusted_cypher/1.1.0/rusted_cypher/graph/struct.GraphClient.html) | `cypher_pool` | None | +//! | Redis | [`Redis-rs`](https://github.com/mitsuhiko/redis-rs) | [`redis::Connection`](https://docs.rs/redis/0.9.0/redis/struct.Connection.html) | `redis_pool` | None | +//! +//! ### How to use the table +//! The above table lists all the supported database adapters in this library. +//! In order to use particular `Poolable` type that's included in this library, +//! you must first enable the feature listed in the 'Feature' column. The inner +//! type you should use for your database type should be what's listed in the +//! corresponding `Poolable` Type column. +//! +//! ## Extending +//! +//! Extending Rocket's support to your own custom database adapter (or other +//! database-like struct that can be pooled by r2d2) is as easy as implementing +//! the `Poolable` trait for your own type. See the documentation for the +//! [`Poolable` trait](trait.Poolable.html) for more details on how to implement +//! it and extend your type for use with Rocket's database pooling feature. + +pub extern crate r2d2; + +use std::collections::BTreeMap; +use std::fmt::{self, Display, Formatter}; +use std::marker::{Send, Sized}; + +use rocket::config::{self, Value}; + +pub use rocket_contrib_codegen::database; + +use self::r2d2::ManageConnection; + +#[cfg(any(feature = "diesel_sqlite_pool", feature = "diesel_postgres_pool", feature = "diesel_mysql_pool"))] +pub extern crate diesel; + +#[cfg(feature = "postgres_pool")] +pub extern crate postgres; +#[cfg(feature = "postgres_pool")] +pub extern crate r2d2_postgres; + +#[cfg(feature = "mysql_pool")] +pub extern crate mysql; +#[cfg(feature = "mysql_pool")] +pub extern crate r2d2_mysql; + +#[cfg(feature = "sqlite_pool")] +pub extern crate rusqlite; +#[cfg(feature = "sqlite_pool")] +pub extern crate r2d2_sqlite; + +#[cfg(feature = "cypher_pool")] +pub extern crate rusted_cypher; +#[cfg(feature = "cypher_pool")] +pub extern crate r2d2_cypher; + +#[cfg(feature = "redis_pool")] +pub extern crate redis; +#[cfg(feature = "redis_pool")] +pub extern crate r2d2_redis; + +/// A struct containing database configuration options from some configuration. +/// +/// For the following configuration: +/// +/// ```toml +/// [[global.databases.my_database]] +/// url = "postgres://root:root@localhost/my_database +/// pool_size = 10 +/// certs = "sample_cert.pem" +/// key = "key.pem" +/// ``` +/// +/// The following structure would be generated after calling +/// `database_config("my_database", &some_config)`: +/// +/// ```ignore +/// DatabaseConfig { +/// url: "dummy_db.sqlite", +/// pool_size: 10, +/// extras: { +/// "certs": String("certs.pem"), +/// "key": String("key.pem") +/// } +/// } +/// ``` +#[derive(Debug, Clone, PartialEq)] +pub struct DatabaseConfig<'a> { + /// The connection URL specified in the Rocket configuration. + pub url: &'a str, + /// The size of the pool to be initialized. Defaults to the number of + /// Rocket workers. + pub pool_size: u32, + /// Any extra options that are included in the configuration, **excluding** + /// the url and pool_size. + pub extras: BTreeMap, +} + +/// A wrapper around `r2d2::Error`s or a custom database error type. This type +/// is mostly relevant to implementors of the [Poolable](trait.Poolable.html) +/// trait. +/// +/// Example usages of this type are in the `Poolable` implementations that ship +/// with `rocket_contrib`. +#[derive(Debug)] +pub enum DbError { + /// The custom error type to wrap alongside `r2d2::Error`. + Custom(T), + /// The error returned by an r2d2 pool. + PoolError(r2d2::Error), +} + +/// The error type for fetching the DatabaseConfig +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DatabaseConfigError { + /// Returned when the `[[global.databases]]` key is missing or empty from + /// the loaded configuration. + MissingTable, + /// Returned when the database configuration key is missing from the active + /// configuration. + MissingKey, + /// Returned when the configuration associated with the key isn't in the + /// expected [Table](../../rocket/config/type.Table.html) format. + MalformedConfiguration, + /// Returned when the `url` field is missing. + MissingUrl, + /// Returned when the `url` field is of the wrong type. + MalformedUrl, + /// Returned when the `pool_size` exceeds `u32::max_value()` or is negative. + InvalidPoolSize(i64), +} + +/// This method retrieves the database configuration from the loaded +/// configuration and returns a [`DatabaseConfig`](struct.DatabaseConfig.html) +/// struct. +/// +/// # Example: +/// +/// Given the following configuration: +/// +/// ```toml +/// [[global.databases]] +/// my_db = { url = "db/db.sqlite", pool_size = 25 } +/// my_other_db = { url = "mysql://root:root@localhost/database" } +/// ``` +/// +/// Calling the `database_config` method will return the +/// [`DatabaseConfig`](struct.DatabaseConfig.html) structure for any valid +/// configuration key. See the example code below. +/// +/// ```rust +/// # extern crate rocket; +/// # extern crate rocket_contrib; +/// # +/// # use std::{collections::BTreeMap, mem::drop}; +/// # use rocket::{fairing::AdHoc, config::{Config, Environment, Value}}; +/// use rocket_contrib::databases::{database_config, DatabaseConfigError}; +/// +/// # let mut databases = BTreeMap::new(); +/// # +/// # let mut my_db = BTreeMap::new(); +/// # my_db.insert("url".to_string(), Value::from("db/db.sqlite")); +/// # my_db.insert("pool_size".to_string(), Value::from(25)); +/// # +/// # let mut my_other_db = BTreeMap::new(); +/// # my_other_db.insert("url".to_string(), Value::from("mysql://root:root@localhost/database")); +/// # +/// # databases.insert("my_db".to_string(), Value::from(my_db)); +/// # databases.insert("my_other_db".to_string(), Value::from(my_other_db)); +/// # +/// # let config = Config::build(Environment::Development).extra("databases", databases).expect("custom config okay"); +/// # +/// # rocket::custom(config).attach(AdHoc::on_attach(|rocket| { +/// # // HACK: This is a dirty hack required to be able to make this work +/// # let thing = { +/// # let rocket_config = rocket.config(); +/// let config = database_config("my_db", rocket_config).expect("my_db config okay"); +/// assert_eq!(config.url, "db/db.sqlite"); +/// assert_eq!(config.pool_size, 25); +/// +/// let other_config = database_config("my_other_db", rocket_config).expect("my_other_db config okay"); +/// assert_eq!(other_config.url, "mysql://root:root@localhost/database"); +/// +/// let error = database_config("invalid_db", rocket_config).unwrap_err(); +/// assert_eq!(error, DatabaseConfigError::MissingKey); +/// # +/// # 10 +/// # }; +/// # +/// # Ok(rocket) +/// # })); +/// ``` +pub fn database_config<'a>( + name: &str, + from: &'a config::Config +) -> Result, DatabaseConfigError> { + // Find the first `databases` config that's a table with a key of 'name' + // equal to `name`. + let connection_config = from.get_table("databases") + .map_err(|_| DatabaseConfigError::MissingTable)? + .get(name) + .ok_or(DatabaseConfigError::MissingKey)? + .as_table() + .ok_or(DatabaseConfigError::MalformedConfiguration)?; + + let maybe_url = connection_config.get("url") + .ok_or(DatabaseConfigError::MissingUrl)?; + + let url = maybe_url.as_str().ok_or(DatabaseConfigError::MalformedUrl)?; + + let pool_size = connection_config.get("pool_size") + .and_then(Value::as_integer) + .unwrap_or(from.workers as i64); + + if pool_size < 1 || pool_size > u32::max_value() as i64 { + return Err(DatabaseConfigError::InvalidPoolSize(pool_size)); + } + + let mut extras = connection_config.clone(); + extras.remove("url"); + extras.remove("pool_size"); + + Ok(DatabaseConfig { url, pool_size: pool_size as u32, extras: extras }) +} + +impl<'a> Display for DatabaseConfigError { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + DatabaseConfigError::MissingTable => { + write!(f, "A table named `databases` was not found for this configuration") + }, + DatabaseConfigError::MissingKey => { + write!(f, "An entry in the `databases` table was not found for this key") + }, + DatabaseConfigError::MalformedConfiguration => { + write!(f, "The configuration for this database is malformed") + } + DatabaseConfigError::MissingUrl => { + write!(f, "The connection URL is missing for this database") + }, + DatabaseConfigError::MalformedUrl => { + write!(f, "The specified connection URL is malformed") + }, + DatabaseConfigError::InvalidPoolSize(invalid_size) => { + write!(f, "'{}' is not a valid value for `pool_size`", invalid_size) + }, + } + } +} + +/// Trait implemented by database adapters to allow for r2d2 connection pools to +/// be easily created. +/// +/// # Provided Implementations +/// +/// Rocket Contrib implements `Poolable` on several common database adapters. +/// The provided implementations are listed here. +/// +/// * **diesel::MysqlConnection** +/// +/// * **diesel::PgConnection** +/// +/// * **diesel::SqliteConnection** +/// +/// * **postgres::Connection** +/// +/// * **mysql::Conn** +/// +/// * **rusqlite::Connection** +/// +/// * **rusted_cypher::GraphClient** +/// +/// * **redis::Connection** +/// +/// # Implementation Guide +/// +/// As a r2d2-compatible database (or other resource) adapter provider, +/// implementing `Poolable` in your own library will enable Rocket users to +/// consume your adapter with its built-in connection pooling primitives. +/// +/// ## Example +/// +/// This example assumes a `FooConnectionManager` implementing the +/// `ManageConnection`trait required by r2d2. This connection manager abstracts +/// over a pool of `FooClient` connections. +/// +/// Given the following definition of the client and connection manager: +/// +/// ```rust,ignore +/// struct FooClient { ... }; +/// +/// impl FooClient { +/// pub fn new(...) -> Result { +/// ... +/// } +/// } +/// +/// struct FooConnectionManager { ... }; +/// +/// impl FooConnectionManager { +/// pub fn new(...) -> Result { +/// ... +/// } +/// } +/// ``` +/// +/// In order to allow for Rocket Contrib to generate the required code to +/// automatically provision a r2d2 connection pool into application state, the +/// `Poolable` trait needs to be implemented for the connection type. +/// +/// Given the above definitions, the following would be a valid implementation +/// of the `Poolable` trait: +/// +/// ```rust,ignore +/// impl Poolable for FooClient { +/// type Manager = FooConnectionManager; +/// type Error = DbError; +/// +/// fn pool(config: DatabaseConfig) -> Result, Self::Error> { +/// let manager = FooConnectionManager::new(config.url) +/// .map_err(DbError::Custom)?; +/// +/// r2d2::Pool::builder().max_size(config.pool_size).build(manager) +/// .map_err(DbError::PoolError) +/// } +/// } +/// ``` +/// +/// In the above example, the connection manager is failable and returns the the +/// `FooClient`'s error type. Since the error type can diverge from a simple +/// r2d2 pool error, the [`DbError`](enum.DbError.html) wrapper is used. This +/// error type is defined as part of the associated type in the `Poolable` trait +/// definition. +/// +/// Additionally, you'll notice that the `pool` method of the trait is used to +/// to create the connection manager and the pool. This method returns a +/// `Result` containing an r2d2 pool monomorphized to the `Manager` associated +/// type in the trait definition, or containing the `Error` associated type. +/// +/// In the event that the connection manager isn't failable (as is the case in +/// Diesel's r2d2 connection manager, for example), the associated error type +/// for the `Poolable` implementation can simply be `r2d2::Error` as this is the +/// only error that can be returned by the `pool` method. You can refer to the +/// included implementations of `Poolable` in the `rocket_contrib::databases` +/// module for concrete examples. +/// +pub trait Poolable: Send + Sized + 'static { + /// The associated connection manager for the given connection type. + type Manager: ManageConnection; + /// The associated error type in the event that constructing the connection + /// manager and/or the connection pool fails + type Error; + + /// Creates an r2d2 connection pool from the provided Manager associated + /// type and returns the pool or the error associated with the trait + /// implementation. + fn pool(config: DatabaseConfig) -> Result, Self::Error>; +} + +#[cfg(feature = "diesel_sqlite_pool")] +impl Poolable for diesel::SqliteConnection { + type Manager = diesel::r2d2::ConnectionManager; + type Error = r2d2::Error; + + fn pool(config: DatabaseConfig) -> Result, Self::Error> { + let manager = diesel::r2d2::ConnectionManager::new(config.url); + r2d2::Pool::builder().max_size(config.pool_size).build(manager) + } +} + +#[cfg(feature = "diesel_pg_pool")] +impl Poolable for diesel::PgConnection { + type Manager = diesel::r2d2::ConnectionManager; + type Error = r2d2::Error; + + fn pool(config: DatabaseConfig) -> Result, Self::Error> { + let manager = diesel::r2d2::ConnectionManager::new(config.url); + r2d2::Pool::builder().max_size(config.pool_size).build(manager) + } +} + +#[cfg(feature = "diesel_mysql_pool")] +impl Poolable for diesel::MysqlConnection { + type Manager = diesel::r2d2::ConnectionManager; + type Error = r2d2::Error; + + fn pool(config: DatabaseConfig) -> Result, Self::Error> { + let manager = diesel::r2d2::ConnectionManager::new(config.url); + r2d2::Pool::builder().max_size(config.pool_size).build(manager) + } +} + +// TODO: Come up with a way to handle TLS +#[cfg(feature = "postgres_pool")] +impl Poolable for postgres::Connection { + type Manager = r2d2_postgres::PostgresConnectionManager; + type Error = DbError; + + fn pool(config: DatabaseConfig) -> Result, Self::Error> { + let manager = r2d2_postgres::PostgresConnectionManager::new(config.url, r2d2_postgres::TlsMode::None) + .map_err(DbError::Custom)?; + + r2d2::Pool::builder().max_size(config.pool_size).build(manager) + .map_err(DbError::PoolError) + } +} + +#[cfg(feature = "mysql_pool")] +impl Poolable for mysql::Conn { + type Manager = r2d2_mysql::MysqlConnectionManager; + type Error = r2d2::Error; + + fn pool(config: DatabaseConfig) -> Result, Self::Error> { + let opts = mysql::OptsBuilder::from_opts(config.url); + let manager = r2d2_mysql::MysqlConnectionManager::new(opts); + r2d2::Pool::builder().max_size(config.pool_size).build(manager) + } +} + +#[cfg(feature = "sqlite_pool")] +impl Poolable for rusqlite::Connection { + type Manager = r2d2_sqlite::SqliteConnectionManager; + type Error = r2d2::Error; + + fn pool(config: DatabaseConfig) -> Result, Self::Error> { + let manager = r2d2_sqlite::SqliteConnectionManager::file(config.url); + + r2d2::Pool::builder().max_size(config.pool_size).build(manager) + } +} + +#[cfg(feature = "cypher_pool")] +impl Poolable for rusted_cypher::GraphClient { + type Manager = r2d2_cypher::CypherConnectionManager; + type Error = r2d2::Error; + + fn pool(config: DatabaseConfig) -> Result, Self::Error> { + let manager = r2d2_cypher::CypherConnectionManager { url: config.url.to_string() }; + r2d2::Pool::builder().max_size(config.pool_size).build(manager) + } +} + +#[cfg(feature = "redis_pool")] +impl Poolable for redis::Connection { + type Manager = r2d2_redis::RedisConnectionManager; + type Error = DbError; + + fn pool(config: DatabaseConfig) -> Result, Self::Error> { + let manager = r2d2_redis::RedisConnectionManager::new(config.url).map_err(DbError::Custom)?; + r2d2::Pool::builder().max_size(config.pool_size).build(manager) + .map_err(DbError::PoolError) + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + use rocket::{Config, config::{Environment, Value}}; + use super::{DatabaseConfigError, database_config}; + + #[test] + fn no_database_entry_in_config_returns_error() { + let config = Config::build(Environment::Development) + .finalize() + .unwrap(); + let database_config_result = database_config("dummy_db", &config); + + assert_eq!(Err(DatabaseConfigError::MissingTable), database_config_result); + } + + #[test] + fn no_matching_connection_returns_error() { + // Laboriously setup the config extras + let mut database_extra = BTreeMap::new(); + let mut connection_config = BTreeMap::new(); + connection_config.insert("url".to_string(), Value::from("dummy_db.sqlite")); + connection_config.insert("pool_size".to_string(), Value::from(10)); + database_extra.insert("dummy_db".to_string(), Value::from(connection_config)); + + let config = Config::build(Environment::Development) + .extra("databases", database_extra) + .finalize() + .unwrap(); + + let database_config_result = database_config("real_db", &config); + + assert_eq!(Err(DatabaseConfigError::MissingKey), database_config_result); + } + + #[test] + fn incorrectly_structured_config_returns_error() { + let mut database_extra = BTreeMap::new(); + let connection_config = vec!["url", "dummy_db.slqite"]; + database_extra.insert("dummy_db".to_string(), Value::from(connection_config)); + + let config = Config::build(Environment::Development) + .extra("databases", database_extra) + .finalize() + .unwrap(); + + let database_config_result = database_config("dummy_db", &config); + + assert_eq!(Err(DatabaseConfigError::MalformedConfiguration), database_config_result); + } + + #[test] + fn missing_connection_string_returns_error() { + let mut database_extra = BTreeMap::new(); + let connection_config: BTreeMap = BTreeMap::new(); + database_extra.insert("dummy_db", connection_config); + + let config = Config::build(Environment::Development) + .extra("databases", database_extra) + .finalize() + .unwrap(); + + let database_config_result = database_config("dummy_db", &config); + + assert_eq!(Err(DatabaseConfigError::MissingUrl), database_config_result); + } + + #[test] + fn invalid_connection_string_returns_error() { + let mut database_extra = BTreeMap::new(); + let mut connection_config = BTreeMap::new(); + connection_config.insert("url".to_string(), Value::from(42)); + database_extra.insert("dummy_db", connection_config); + + let config = Config::build(Environment::Development) + .extra("databases", database_extra) + .finalize() + .unwrap(); + + let database_config_result = database_config("dummy_db", &config); + + assert_eq!(Err(DatabaseConfigError::MalformedUrl), database_config_result); + } + + #[test] + fn negative_pool_size_returns_error() { + let mut database_extra = BTreeMap::new(); + let mut connection_config = BTreeMap::new(); + connection_config.insert("url".to_string(), Value::from("dummy_db.sqlite")); + connection_config.insert("pool_size".to_string(), Value::from(-1)); + database_extra.insert("dummy_db", connection_config); + + let config = Config::build(Environment::Development) + .extra("databases", database_extra) + .finalize() + .unwrap(); + + let database_config_result = database_config("dummy_db", &config); + + assert_eq!(Err(DatabaseConfigError::InvalidPoolSize(-1)), database_config_result); + } + + #[test] + fn pool_size_beyond_u32_max_returns_error() { + let mut database_extra = BTreeMap::new(); + let mut connection_config = BTreeMap::new(); + connection_config.insert("url".to_string(), Value::from("dummy_db.sqlite")); + connection_config.insert("pool_size".to_string(), Value::from(4294967296)); + database_extra.insert("dummy_db", connection_config); + + let config = Config::build(Environment::Development) + .extra("databases", database_extra) + .finalize() + .unwrap(); + + let database_config_result = database_config("dummy_db", &config); + + // The size of `0` is an overflow wrap-around + assert_eq!(Err(DatabaseConfigError::InvalidPoolSize(0)), database_config_result); + } + + #[test] + fn happy_path_database_config() { + let url = "dummy_db.sqlite"; + let pool_size = 10; + + let mut database_extra = BTreeMap::new(); + let mut connection_config = BTreeMap::new(); + connection_config.insert("url".to_string(), Value::from(url)); + connection_config.insert("pool_size".to_string(), Value::from(pool_size)); + database_extra.insert("dummy_db", connection_config); + + let config = Config::build(Environment::Development) + .extra("databases", database_extra) + .finalize() + .unwrap(); + + let database_config = database_config("dummy_db", &config).unwrap(); + + assert_eq!(url, database_config.url); + assert_eq!(pool_size, database_config.pool_size); + assert_eq!(0, database_config.extras.len()); + } + + #[test] + fn extras_do_not_contain_required_keys() { + let url = "dummy_db.sqlite"; + let pool_size = 10; + + let mut database_extra = BTreeMap::new(); + let mut connection_config = BTreeMap::new(); + connection_config.insert("url".to_string(), Value::from(url)); + connection_config.insert("pool_size".to_string(), Value::from(pool_size)); + database_extra.insert("dummy_db", connection_config); + + let config = Config::build(Environment::Development) + .extra("databases", database_extra) + .finalize() + .unwrap(); + + let database_config = database_config("dummy_db", &config).unwrap(); + + assert_eq!(url, database_config.url); + assert_eq!(pool_size, database_config.pool_size); + assert_eq!(false, database_config.extras.contains_key("url")); + assert_eq!(false, database_config.extras.contains_key("pool_size")); + } + + #[test] + fn extra_values_are_placed_in_extras_map() { + let url = "dummy_db.sqlite"; + let pool_size = 10; + let tls_cert = "certs.pem"; + let tls_key = "key.pem"; + + let mut database_extra = BTreeMap::new(); + let mut connection_config = BTreeMap::new(); + connection_config.insert("url".to_string(), Value::from(url)); + connection_config.insert("pool_size".to_string(), Value::from(pool_size)); + connection_config.insert("certs".to_string(), Value::from(tls_cert)); + connection_config.insert("key".to_string(), Value::from(tls_key)); + database_extra.insert("dummy_db", connection_config); + + let config = Config::build(Environment::Development) + .extra("databases", database_extra) + .finalize() + .unwrap(); + + let database_config = database_config("dummy_db", &config).unwrap(); + + assert_eq!(url, database_config.url); + assert_eq!(pool_size, database_config.pool_size); + assert_eq!(true, database_config.extras.contains_key("certs")); + assert_eq!(true, database_config.extras.contains_key("key")); + + println!("{:#?}", database_config); + } +} diff --git a/contrib/lib/src/lib.rs b/contrib/lib/src/lib.rs index bf88e46d..841660e1 100644 --- a/contrib/lib/src/lib.rs +++ b/contrib/lib/src/lib.rs @@ -1,5 +1,6 @@ #![feature(use_extern_macros)] #![feature(crate_visibility_modifier)] +#![feature(never_type)] // TODO: Version URLs. #![doc(html_root_url = "https://api.rocket.rs")] @@ -23,6 +24,7 @@ //! * [handlebars_templates](struct.Template.html) //! * [tera_templates](struct.Template.html) //! * [uuid](struct.Uuid.html) +//! * [database_pool](databases/index.html) //! //! The recommend way to include features from this crate via Cargo in your //! project is by adding a `[dependencies.rocket_contrib]` section to your @@ -86,3 +88,15 @@ pub use uuid::{Uuid, UuidParseError}; #[cfg(feature = "static_files")] pub mod static_files; + +#[cfg(feature = "database_pool")] +pub mod databases; + +#[cfg(feature = "database_pool_codegen")] +#[allow(unused_imports)] +#[macro_use] +extern crate rocket_contrib_codegen; + +#[cfg(feature = "database_pool_codegen")] +#[doc(hidden)] +pub use rocket_contrib_codegen::*; diff --git a/contrib/lib/tests/databases.rs b/contrib/lib/tests/databases.rs new file mode 100644 index 00000000..31d7ff89 --- /dev/null +++ b/contrib/lib/tests/databases.rs @@ -0,0 +1,15 @@ +#![feature(use_extern_macros)] + +extern crate rocket; +extern crate rocket_contrib; + +#[cfg(feature = "databases")] +mod databases_tests { + use rocket_contrib::databases::{database, diesel}; + + #[database("foo")] + struct TempStorage(diesel::SqliteConnection); + + #[database("bar")] + struct PrimaryDb(diesel::PgConnection); +} diff --git a/core/lib/src/logger.rs b/core/lib/src/logger.rs index a875f272..6abd4913 100644 --- a/core/lib/src/logger.rs +++ b/core/lib/src/logger.rs @@ -208,3 +208,10 @@ crate fn pop_max_level() { pub fn init(level: LoggingLevel) -> bool { try_init(level, true) } + +// This method exists as a shim for the log macros that need to be called from +// an end user's code. It was added as part of the work to support database +// connection pools via procedural macros. +pub fn log_err(msg: &str) { + error!("{}", msg); +} diff --git a/examples/todo/Cargo.toml b/examples/todo/Cargo.toml index 2565f0e7..f2eac1af 100644 --- a/examples/todo/Cargo.toml +++ b/examples/todo/Cargo.toml @@ -19,4 +19,4 @@ rand = "0.5" [dependencies.rocket_contrib] path = "../../contrib/lib" default_features = false -features = [ "tera_templates" ] +features = [ "tera_templates", "database_pool", "diesel_sqlite_pool" ] diff --git a/examples/todo/Rocket.toml b/examples/todo/Rocket.toml index f14a995d..cfd6c2d1 100644 --- a/examples/todo/Rocket.toml +++ b/examples/todo/Rocket.toml @@ -1,2 +1,5 @@ [global] template_dir = "static" + +[global.databases.sqlite_database] +url = "db/db.sqlite" diff --git a/examples/todo/bootstrap.sh b/examples/todo/bootstrap.sh index 7e3e7292..f9057f42 100755 --- a/examples/todo/bootstrap.sh +++ b/examples/todo/bootstrap.sh @@ -1,7 +1,7 @@ #! /usr/bin/env bash SCRIPT_PATH=$(cd "$(dirname "$0")" ; pwd -P) -DATABASE_URL="${SCRIPT_PATH}/db/db.sql" +DATABASE_URL="${SCRIPT_PATH}/db/db.sqlite" pushd "${SCRIPT_PATH}" > /dev/null # clear an existing database diff --git a/examples/todo/src/db.rs b/examples/todo/src/db.rs deleted file mode 100644 index 58bb2c37..00000000 --- a/examples/todo/src/db.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::ops::Deref; - -use diesel::sqlite::SqliteConnection; -use diesel::r2d2::{ConnectionManager, Pool, PooledConnection}; - -use rocket::http::Status; -use rocket::request::{self, FromRequest}; -use rocket::{Request, State, Outcome}; - -pub type SqlitePool = Pool>; - -pub const DATABASE_URL: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/db/db.sql"); - -pub fn init_pool() -> SqlitePool { - let manager = ConnectionManager::::new(DATABASE_URL); - Pool::new(manager).expect("db pool") -} - -pub struct Conn(pub PooledConnection>); - -impl Deref for Conn { - type Target = SqliteConnection; - - #[inline(always)] - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl<'a, 'r> FromRequest<'a, 'r> for Conn { - type Error = (); - - fn from_request(request: &'a Request<'r>) -> request::Outcome { - let pool = request.guard::>()?; - match pool.get() { - Ok(conn) => Outcome::Success(Conn(conn)), - Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())) - } - } -} diff --git a/examples/todo/src/main.rs b/examples/todo/src/main.rs index 21b17040..e697b1fa 100644 --- a/examples/todo/src/main.rs +++ b/examples/todo/src/main.rs @@ -1,4 +1,4 @@ -#![feature(plugin, decl_macro, const_fn)] +#![feature(plugin, decl_macro, use_extern_macros, custom_derive, const_fn)] #![plugin(rocket_codegen)] #[macro_use] extern crate rocket; @@ -8,31 +8,35 @@ extern crate rocket_contrib; mod static_files; mod task; -mod db; #[cfg(test)] mod tests; use rocket::Rocket; use rocket::request::{Form, FlashMessage}; use rocket::response::{Flash, Redirect}; use rocket_contrib::Template; +use rocket_contrib::databases::database; +use diesel::SqliteConnection; use task::{Task, Todo}; +#[database("sqlite_database")] +pub struct DbConn(SqliteConnection); + #[derive(Debug, Serialize)] struct Context<'a, 'b>{ msg: Option<(&'a str, &'b str)>, tasks: Vec } impl<'a, 'b> Context<'a, 'b> { - pub fn err(conn: &db::Conn, msg: &'a str) -> Context<'static, 'a> { + pub fn err(conn: &DbConn, msg: &'a str) -> Context<'static, 'a> { Context{msg: Some(("error", msg)), tasks: Task::all(conn)} } - pub fn raw(conn: &db::Conn, msg: Option<(&'a str, &'b str)>) -> Context<'a, 'b> { + pub fn raw(conn: &DbConn, msg: Option<(&'a str, &'b str)>) -> Context<'a, 'b> { Context{msg: msg, tasks: Task::all(conn)} } } #[post("/", data = "")] -fn new(todo_form: Form, conn: db::Conn) -> Flash { +fn new(todo_form: Form, conn: DbConn) -> Flash { let todo = todo_form.into_inner(); if todo.description.is_empty() { Flash::error(Redirect::to("/"), "Description cannot be empty.") @@ -44,7 +48,7 @@ fn new(todo_form: Form, conn: db::Conn) -> Flash { } #[put("/")] -fn toggle(id: i32, conn: db::Conn) -> Result { +fn toggle(id: i32, conn: DbConn) -> Result { if Task::toggle_with_id(id, &conn) { Ok(Redirect::to("/")) } else { @@ -53,7 +57,7 @@ fn toggle(id: i32, conn: db::Conn) -> Result { } #[delete("/")] -fn delete(id: i32, conn: db::Conn) -> Result, Template> { +fn delete(id: i32, conn: DbConn) -> Result, Template> { if Task::delete_with_id(id, &conn) { Ok(Flash::success(Redirect::to("/"), "Todo was deleted.")) } else { @@ -62,27 +66,25 @@ fn delete(id: i32, conn: db::Conn) -> Result, Template> { } #[get("/")] -fn index(msg: Option, conn: db::Conn) -> Template { +fn index(msg: Option, conn: DbConn) -> Template { Template::render("index", &match msg { Some(ref msg) => Context::raw(&conn, Some((msg.name(), msg.msg()))), None => Context::raw(&conn, None), }) } -fn rocket() -> (Rocket, Option) { - let pool = db::init_pool(); - let conn = if cfg!(test) { - Some(db::Conn(pool.get().expect("database connection for testing"))) - } else { - None - }; - +fn rocket() -> (Rocket, Option) { let rocket = rocket::ignite() - .manage(pool) + .attach(DbConn::fairing()) .mount("/", routes![index, static_files::all]) .mount("/todo", routes![new, toggle, delete]) .attach(Template::fairing()); + let conn = match cfg!(test) { + true => DbConn::get_one(&rocket), + false => None, + }; + (rocket, conn) } diff --git a/examples/todo/src/task.rs b/examples/todo/src/task.rs index d86fc0a0..a1d2bbc9 100644 --- a/examples/todo/src/task.rs +++ b/examples/todo/src/task.rs @@ -1,6 +1,4 @@ -use diesel; -use diesel::prelude::*; -use diesel::sqlite::SqliteConnection; +use diesel::{self, prelude::*}; mod schema { table! {