use proc_macro::TokenStream; use devise::{DeriveGenerator, FromMeta, MapperBuild, Support, ValidatorBuild}; use devise::proc_macro2_diagnostics::SpanDiagnosticExt; use devise::syn::{self, spanned::Spanned}; const ONE_DATABASE_ATTR: &str = "missing `#[database(\"name\")]` attribute"; const ONE_UNNAMED_FIELD: &str = "struct must have exactly one unnamed field"; #[derive(Debug, FromMeta)] struct DatabaseAttribute { #[meta(naked)] name: String, } pub fn derive_database(input: TokenStream) -> TokenStream { DeriveGenerator::build_for(input, quote!(impl rocket_db_pools::Database)) .support(Support::TupleStruct) .validator(ValidatorBuild::new() .struct_validate(|_, s| { if s.fields.len() == 1 { Ok(()) } else { Err(s.fields.span().error(ONE_UNNAMED_FIELD)) } }) ) .outer_mapper(MapperBuild::new() .struct_map(|_, s| { let decorated_type = &s.ident; let pool_type = match &s.fields { syn::Fields::Unnamed(f) => &f.unnamed[0].ty, _ => unreachable!("Support::TupleStruct"), }; quote_spanned! { s.span() => impl From<#pool_type> for #decorated_type { fn from(pool: #pool_type) -> Self { Self(pool) } } impl std::ops::Deref for #decorated_type { type Target = #pool_type; fn deref(&self) -> &Self::Target { &self.0 } } impl std::ops::DerefMut for #decorated_type { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } #[rocket::async_trait] impl<'r> rocket::request::FromRequest<'r> for &'r #decorated_type { type Error = (); async fn from_request( req: &'r rocket::request::Request<'_> ) -> rocket::request::Outcome { match #decorated_type::fetch(req.rocket()) { Some(db) => rocket::outcome::Outcome::Success(db), None => rocket::outcome::Outcome::Failure(( rocket::http::Status::InternalServerError, ())) } } } impl rocket::Sentinel for &#decorated_type { fn abort(rocket: &rocket::Rocket) -> bool { #decorated_type::fetch(rocket).is_none() } } } }) ) .outer_mapper(quote!(#[rocket::async_trait])) .inner_mapper(MapperBuild::new() .try_struct_map(|_, s| { let db_name = DatabaseAttribute::one_from_attrs("database", &s.attrs)? .map(|attr| attr.name) .ok_or_else(|| s.span().error(ONE_DATABASE_ATTR))?; let fairing_name = format!("'{}' Database Pool", db_name); let pool_type = match &s.fields { syn::Fields::Unnamed(f) => &f.unnamed[0].ty, _ => unreachable!("Support::TupleStruct"), }; Ok(quote_spanned! { s.span() => type Pool = #pool_type; const NAME: &'static str = #db_name; fn init() -> rocket_db_pools::Initializer { rocket_db_pools::Initializer::with_name(#fairing_name) } }) }) ) .to_tokens() }