2018-08-15 09:07:17 +00:00
|
|
|
use proc_macro::TokenStream;
|
2018-10-29 04:09:04 +00:00
|
|
|
use devise::{Spanned, Result};
|
2019-06-13 02:17:59 +00:00
|
|
|
use crate::syn::{DataStruct, Fields, Data, Type, LitStr, DeriveInput, Ident, Visibility};
|
2018-07-21 22:11:08 +00:00
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
struct DatabaseInvocation {
|
|
|
|
/// The name of the structure on which `#[database(..)] struct This(..)` was invoked.
|
|
|
|
type_name: Ident,
|
|
|
|
/// The visibility of the structure on which `#[database(..)] struct This(..)` was invoked.
|
|
|
|
visibility: Visibility,
|
|
|
|
/// The database name as passed in via #[database('database name')].
|
|
|
|
db_name: String,
|
|
|
|
/// The entire structure that the `database` attribute was called on.
|
|
|
|
structure: DataStruct,
|
|
|
|
/// The type inside the structure: struct MyDb(ThisType).
|
|
|
|
connection_type: Type,
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Help text showing a well-formed use of the `database` attribute.
const EXAMPLE: &str = "example: `struct MyDatabase(diesel::SqliteConnection);`";

/// Diagnostic emitted when the attribute is applied to something other than a struct.
const ONLY_ON_STRUCTS_MSG: &str = "`database` attribute can only be used on structs";

/// Diagnostic emitted when the struct is not a tuple struct with exactly one field.
const ONLY_UNNAMED_FIELDS: &str = concat!(
    "`database` attribute can only be applied to ",
    "structs with exactly one unnamed field",
);

/// Diagnostic emitted when the struct declares generic parameters.
const NO_GENERIC_STRUCTS: &str = concat!(
    "`database` attribute cannot be applied to structs ",
    "with generics",
);
|
2018-07-21 22:11:08 +00:00
|
|
|
|
|
|
|
fn parse_invocation(attr: TokenStream, input: TokenStream) -> Result<DatabaseInvocation> {
|
2019-06-13 02:17:59 +00:00
|
|
|
let attr_stream2 = crate::proc_macro2::TokenStream::from(attr);
|
2018-07-21 22:11:08 +00:00
|
|
|
let attr_span = attr_stream2.span();
|
2019-06-13 02:17:59 +00:00
|
|
|
let string_lit = crate::syn::parse2::<LitStr>(attr_stream2)
|
2018-07-21 22:11:08 +00:00
|
|
|
.map_err(|_| attr_span.error("expected string literal"))?;
|
|
|
|
|
2019-06-13 02:17:59 +00:00
|
|
|
let input = crate::syn::parse::<DeriveInput>(input).unwrap();
|
2018-07-21 22:11:08 +00:00
|
|
|
if !input.generics.params.is_empty() {
|
2018-12-12 08:00:10 +00:00
|
|
|
return Err(input.generics.span().error(NO_GENERIC_STRUCTS));
|
2018-07-21 22:11:08 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
let structure = match input.data {
|
|
|
|
Data::Struct(s) => s,
|
|
|
|
_ => return Err(input.span().error(ONLY_ON_STRUCTS_MSG))
|
|
|
|
};
|
|
|
|
|
|
|
|
let inner_type = match structure.fields {
|
|
|
|
Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
|
|
|
|
let first = fields.unnamed.first().expect("checked length");
|
2019-09-05 22:43:57 +00:00
|
|
|
first.ty.clone()
|
2018-07-21 22:11:08 +00:00
|
|
|
}
|
|
|
|
_ => return Err(structure.fields.span().error(ONLY_UNNAMED_FIELDS).help(EXAMPLE))
|
|
|
|
};
|
|
|
|
|
|
|
|
Ok(DatabaseInvocation {
|
|
|
|
type_name: input.ident,
|
|
|
|
visibility: input.vis,
|
|
|
|
db_name: string_lit.value(),
|
|
|
|
structure: structure,
|
|
|
|
connection_type: inner_type,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2018-12-12 08:00:10 +00:00
|
|
|
#[allow(non_snake_case)]
|
2018-07-21 22:11:08 +00:00
|
|
|
pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStream> {
|
|
|
|
let invocation = parse_invocation(attr, input)?;
|
|
|
|
|
2018-08-15 09:07:17 +00:00
|
|
|
// Store everything we're going to need to generate code.
|
2018-12-12 08:00:10 +00:00
|
|
|
let conn_type = &invocation.connection_type;
|
2018-08-15 09:07:17 +00:00
|
|
|
let name = &invocation.db_name;
|
2018-12-12 08:00:10 +00:00
|
|
|
let guard_type = &invocation.type_name;
|
2018-08-15 09:07:17 +00:00
|
|
|
let vis = &invocation.visibility;
|
2018-12-12 08:00:10 +00:00
|
|
|
let pool_type = Ident::new(&format!("{}Pool", guard_type), guard_type.span());
|
2018-08-15 09:07:17 +00:00
|
|
|
let fairing_name = format!("'{}' Database Pool", name);
|
2018-12-12 08:00:10 +00:00
|
|
|
let span = conn_type.span().into();
|
2018-08-15 09:07:17 +00:00
|
|
|
|
|
|
|
// A few useful paths.
|
2018-12-12 08:00:10 +00:00
|
|
|
let databases = quote_spanned!(span => ::rocket_contrib::databases);
|
|
|
|
let Poolable = quote_spanned!(span => #databases::Poolable);
|
|
|
|
let r2d2 = quote_spanned!(span => #databases::r2d2);
|
2019-12-11 00:34:23 +00:00
|
|
|
let spawn_blocking = quote_spanned!(span => #databases::spawn_blocking);
|
2018-08-15 09:07:17 +00:00
|
|
|
let request = quote!(::rocket::request);
|
2018-07-21 22:11:08 +00:00
|
|
|
|
2018-12-12 08:00:10 +00:00
|
|
|
let generated_types = quote_spanned! { span =>
|
2018-08-15 09:07:17 +00:00
|
|
|
/// The request guard type.
|
2018-12-12 08:00:10 +00:00
|
|
|
#vis struct #guard_type(pub #r2d2::PooledConnection<<#conn_type as #Poolable>::Manager>);
|
2018-08-15 09:07:17 +00:00
|
|
|
|
|
|
|
/// The pool type.
|
2018-12-12 08:00:10 +00:00
|
|
|
#vis struct #pool_type(#r2d2::Pool<<#conn_type as #Poolable>::Manager>);
|
|
|
|
};
|
|
|
|
|
|
|
|
Ok(quote! {
|
|
|
|
#generated_types
|
2018-07-21 22:11:08 +00:00
|
|
|
|
2018-12-12 08:00:10 +00:00
|
|
|
impl #guard_type {
|
2018-08-15 09:07:17 +00:00
|
|
|
/// Returns a fairing that initializes the associated database
|
|
|
|
/// connection pool.
|
2018-07-21 22:11:08 +00:00
|
|
|
pub fn fairing() -> impl ::rocket::fairing::Fairing {
|
2018-08-15 09:07:17 +00:00
|
|
|
use #databases::Poolable;
|
2018-07-21 22:11:08 +00:00
|
|
|
|
2020-06-14 15:57:53 +00:00
|
|
|
::rocket::fairing::AdHoc::on_attach(#fairing_name, |mut rocket| async {
|
2020-06-14 15:57:54 +00:00
|
|
|
let pool = #databases::database_config(#name, rocket.config().await)
|
2018-12-12 08:00:10 +00:00
|
|
|
.map(<#conn_type>::pool);
|
2018-07-21 22:11:08 +00:00
|
|
|
|
|
|
|
match pool {
|
|
|
|
Ok(Ok(p)) => Ok(rocket.manage(#pool_type(p))),
|
|
|
|
Err(config_error) => {
|
2018-11-12 21:08:39 +00:00
|
|
|
::rocket::logger::error(
|
2018-08-15 09:07:17 +00:00
|
|
|
&format!("Database configuration failure: '{}'", #name));
|
2018-11-12 21:08:39 +00:00
|
|
|
::rocket::logger::error_(&format!("{}", config_error));
|
2018-07-21 22:11:08 +00:00
|
|
|
Err(rocket)
|
|
|
|
},
|
|
|
|
Ok(Err(pool_error)) => {
|
2018-11-12 21:08:39 +00:00
|
|
|
::rocket::logger::error(
|
2018-08-15 09:07:17 +00:00
|
|
|
&format!("Failed to initialize pool for '{}'", #name));
|
2018-11-12 21:08:39 +00:00
|
|
|
::rocket::logger::error_(&format!("{:?}", pool_error));
|
2018-07-21 22:11:08 +00:00
|
|
|
Err(rocket)
|
|
|
|
},
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2018-08-15 09:07:17 +00:00
|
|
|
/// Retrieves a connection of type `Self` from the `rocket`
|
|
|
|
/// instance. Returns `Some` as long as `Self::fairing()` has been
|
|
|
|
/// attached and there is at least one connection in the pool.
|
2020-06-14 15:57:51 +00:00
|
|
|
pub fn get_one(manifest: &::rocket::Manifest) -> Option<Self> {
|
|
|
|
manifest.state::<#pool_type>()
|
2018-07-21 22:11:08 +00:00
|
|
|
.and_then(|pool| pool.0.get().ok())
|
2018-12-12 08:00:10 +00:00
|
|
|
.map(#guard_type)
|
2018-07-21 22:11:08 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-12-12 08:00:10 +00:00
|
|
|
impl ::std::ops::Deref for #guard_type {
|
|
|
|
type Target = #conn_type;
|
2018-07-21 22:11:08 +00:00
|
|
|
|
|
|
|
#[inline(always)]
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
|
|
&self.0
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-12-12 08:00:10 +00:00
|
|
|
impl ::std::ops::DerefMut for #guard_type {
|
2018-12-11 21:57:23 +00:00
|
|
|
#[inline(always)]
|
|
|
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
|
|
&mut self.0
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-01-31 09:34:15 +00:00
|
|
|
#[::rocket::async_trait]
|
|
|
|
impl<'a, 'r> #request::FromRequest<'a, 'r> for #guard_type {
|
2018-07-21 22:11:08 +00:00
|
|
|
type Error = ();
|
|
|
|
|
2020-01-31 09:34:15 +00:00
|
|
|
async fn from_request(request: &'a #request::Request<'r>) -> #request::Outcome<Self, ()> {
|
2018-08-15 09:07:17 +00:00
|
|
|
use ::rocket::{Outcome, http::Status};
|
2020-01-31 09:34:15 +00:00
|
|
|
|
|
|
|
let guard = request.guard::<::rocket::State<'_, #pool_type>>();
|
|
|
|
let pool = ::rocket::try_outcome!(guard.await).0.clone();
|
|
|
|
|
|
|
|
#spawn_blocking(move || {
|
|
|
|
match pool.get() {
|
|
|
|
Ok(conn) => Outcome::Success(#guard_type(conn)),
|
|
|
|
Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())),
|
|
|
|
}
|
|
|
|
}).await.expect("failed to spawn a blocking task to get a pooled connection")
|
2018-07-21 22:11:08 +00:00
|
|
|
}
|
|
|
|
}
|
2019-12-11 00:34:23 +00:00
|
|
|
|
|
|
|
// TODO.async: What about spawn_blocking on drop?
|
2018-08-15 09:07:17 +00:00
|
|
|
}.into())
|
2018-07-21 22:11:08 +00:00
|
|
|
}
|