Add fairing retrieval methods to 'Rocket'.

Introduces four new methods:

  * `Rocket::fairing::<F>()`
  * `Rocket::fairing_mut::<F>()`
  * `Rocket::fairings::<F>()`
  * `Rocket::fairings_mut::<F>()`

These methods allow retrieving references to fairings of type `F` from
an instance of `Rocket`. The `fairing` and `fairing_mut` methods return
a (mutable) reference to the first attached fairing of type `F`, while
the `fairings` and `fairings_mut` methods return an iterator over
(mutable) references to all attached fairings of type `F`.

Co-authored-by: Matthew Pomes <matthew.pomes@pm.me>
This commit is contained in:
Sergio Benitez 2024-08-21 01:51:24 -07:00 committed by Sergio Benitez
parent dbeba45b36
commit d3323391ab
4 changed files with 284 additions and 23 deletions

View File

@ -1,5 +1,3 @@
use std::collections::HashSet;
use crate::{Rocket, Request, Response, Data, Build, Orbit}; use crate::{Rocket, Request, Response, Data, Build, Orbit};
use crate::fairing::{Fairing, Info, Kind}; use crate::fairing::{Fairing, Info, Kind};
@ -21,14 +19,15 @@ pub struct Fairings {
macro_rules! iter { macro_rules! iter {
($_self:ident . $kind:ident) => ({ ($_self:ident . $kind:ident) => ({
iter!($_self, $_self.$kind.iter()).map(|v| v.1) iter!($_self, $_self.$kind.iter().copied()).map(|v| v.1)
}); });
($_self:ident, $indices:expr) => ({ ($_self:ident, $indices:expr) => ({
let all_fairings = &$_self.all_fairings; let all_fairings = &$_self.all_fairings;
$indices.filter_map(move |i| { $indices.filter_map(move |i| {
debug_assert!(all_fairings.get(*i).is_some()); let i = i.clone();
let f = all_fairings.get(*i).map(|f| &**f)?; debug_assert!(all_fairings.get(i).is_some());
Some((*i, f)) let f = all_fairings.get(i).map(|f| &**f)?;
Some((i, f))
}) })
}) })
} }
@ -47,10 +46,19 @@ impl Fairings {
.chain(self.shutdown.iter()) .chain(self.shutdown.iter())
} }
pub fn unique_active(&self) -> impl Iterator<Item = usize> {
let mut bitmap = vec![false; self.all_fairings.len()];
for i in self.active() {
bitmap.get_mut(*i).map(|active| *active = true);
}
bitmap.into_iter()
.enumerate()
.filter_map(|(i, active)| active.then_some(i))
}
pub fn unique_set(&self) -> Vec<&dyn Fairing> { pub fn unique_set(&self) -> Vec<&dyn Fairing> {
iter!(self, self.active().collect::<HashSet<_>>().into_iter()) iter!(self, self.unique_active()).map(|v| v.1).collect()
.map(|v| v.1)
.collect()
} }
pub fn add(&mut self, fairing: Box<dyn Fairing>) { pub fn add(&mut self, fairing: Box<dyn Fairing>) {
@ -83,7 +91,7 @@ impl Fairings {
}; };
// Collect all of the active duplicates. // Collect all of the active duplicates.
let mut dups: Vec<usize> = iter!(self, self.active()) let mut dups: Vec<usize> = iter!(self, self.unique_active())
.filter(|(_, f)| f.type_id() == this.type_id()) .filter(|(_, f)| f.type_id() == this.type_id())
.map(|(i, _)| i) .map(|(i, _)| i)
.collect(); .collect();
@ -167,11 +175,32 @@ impl Fairings {
} }
pub fn audit(&self) -> Result<(), &[Info]> { pub fn audit(&self) -> Result<(), &[Info]> {
match self.failures.is_empty() { match &self.failures[..] {
true => Ok(()), [] => Ok(()),
false => Err(&self.failures) failures => Err(failures)
} }
} }
pub fn filter<F: Fairing>(&self) -> impl Iterator<Item = &F> {
iter!(self, self.unique_active())
.filter_map(|v| v.1.downcast_ref::<F>())
}
pub fn filter_mut<F: Fairing>(&mut self) -> impl Iterator<Item = &mut F> {
let mut bitmap = vec![false; self.all_fairings.len()];
for &i in self.active() {
let is_target = self.all_fairings.get(i)
.and_then(|f| f.downcast_ref::<F>())
.is_some();
bitmap.get_mut(i).map(|target| *target = is_target);
}
self.all_fairings.iter_mut()
.enumerate()
.filter(move |(i, _)| *bitmap.get(*i).unwrap_or(&false))
.filter_map(|(_, f)| f.downcast_mut::<F>())
}
} }
impl std::fmt::Debug for Fairings { impl std::fmt::Debug for Fairings {

View File

@ -425,7 +425,7 @@ pub type Result<T = Rocket<Build>, E = Rocket<Build>> = std::result::Result<T, E
/// ///
/// [request-local state]: https://rocket.rs/master/guide/state/#request-local-state /// [request-local state]: https://rocket.rs/master/guide/state/#request-local-state
#[crate::async_trait] #[crate::async_trait]
pub trait Fairing: Send + Sync + Any + 'static { pub trait Fairing: Send + Sync + AsAny + 'static {
/// Returns an [`Info`] structure containing the `name` and [`Kind`] of this /// Returns an [`Info`] structure containing the `name` and [`Kind`] of this
/// fairing. The `name` can be any arbitrary string. `Kind` must be an `or`d /// fairing. The `name` can be any arbitrary string. `Kind` must be an `or`d
/// set of `Kind` variants. /// set of `Kind` variants.
@ -533,6 +533,11 @@ pub trait Fairing: Send + Sync + Any + 'static {
async fn on_shutdown(&self, _rocket: &Rocket<Orbit>) { } async fn on_shutdown(&self, _rocket: &Rocket<Orbit>) { }
} }
pub trait AsAny: Any {
fn as_any_ref(&self) -> &dyn Any;
fn as_any_mut(&mut self) -> &mut dyn Any;
}
#[crate::async_trait] #[crate::async_trait]
impl<T: Fairing + ?Sized> Fairing for std::sync::Arc<T> { impl<T: Fairing + ?Sized> Fairing for std::sync::Arc<T> {
#[inline] #[inline]
@ -565,3 +570,18 @@ impl<T: Fairing + ?Sized> Fairing for std::sync::Arc<T> {
(self as &T).on_shutdown(rocket).await (self as &T).on_shutdown(rocket).await
} }
} }
impl<T: Any> AsAny for T {
fn as_any_ref(&self) -> &dyn Any { self }
fn as_any_mut(&mut self) -> &mut dyn Any { self }
}
impl dyn Fairing {
fn downcast_ref<T: Any>(&self) -> Option<&T> {
self.as_any_ref().downcast_ref()
}
fn downcast_mut<T: Any>(&mut self) -> Option<&mut T> {
self.as_any_mut().downcast_mut()
}
}

View File

@ -14,7 +14,8 @@ mod private {
#[doc(hidden)] #[doc(hidden)]
pub trait Stateful: private::Sealed { pub trait Stateful: private::Sealed {
fn into_state(self) -> State; fn into_state(self) -> State;
fn as_state_ref(&self) -> StateRef<'_>; fn as_ref(&self) -> StateRef<'_>;
fn as_mut(&mut self) -> StateRefMut<'_>;
} }
/// A marker trait for Rocket's launch phases. /// A marker trait for Rocket's launch phases.
@ -48,7 +49,8 @@ macro_rules! phase {
impl Stateful for $S { impl Stateful for $S {
fn into_state(self) -> State { State::$P(self) } fn into_state(self) -> State { State::$P(self) }
fn as_state_ref(&self) -> StateRef<'_> { StateRef::$P(self) } fn as_ref(&self) -> StateRef<'_> { StateRef::$P(self) }
fn as_mut(&mut self) -> StateRefMut<'_> { StateRefMut::$P(self) }
} }
#[doc(hidden)] #[doc(hidden)]
@ -70,6 +72,9 @@ macro_rules! phases {
#[doc(hidden)] #[doc(hidden)]
pub enum StateRef<'a> { $($P(&'a $S)),* } pub enum StateRef<'a> { $($P(&'a $S)),* }
#[doc(hidden)]
pub enum StateRefMut<'a> { $($P(&'a mut $S)),* }
$(phase!($(#[$o])* $P ($(#[$i])* $S) { $($fields)* });)* $(phase!($(#[$o])* $P ($(#[$i])* $S) { $($fields)* });)*
) )
} }

View File

@ -17,7 +17,7 @@ use crate::listener::{Bind, DefaultListener, Endpoint, Listener};
use crate::router::Router; use crate::router::Router;
use crate::fairing::{Fairing, Fairings}; use crate::fairing::{Fairing, Fairings};
use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting}; use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting};
use crate::phase::{Stateful, StateRef, State}; use crate::phase::{Stateful, StateRef, StateRefMut, State};
use crate::http::uri::Origin; use crate::http::uri::Origin;
use crate::http::ext::IntoOwned; use crate::http::ext::IntoOwned;
use crate::error::{Error, ErrorKind}; use crate::error::{Error, ErrorKind};
@ -464,8 +464,11 @@ impl Rocket<Build> {
/// Attaches a fairing to this instance of Rocket. No fairings are eagerly /// Attaches a fairing to this instance of Rocket. No fairings are eagerly
/// executed; fairings are executed at their appropriate time. /// executed; fairings are executed at their appropriate time.
/// ///
/// If the attached fairing is _fungible_ and a fairing of the same name /// If the attached fairing is a [singleton] and a fairing of the same type
/// already exists, this fairing replaces it. /// has already been attached, this fairing replaces it. Otherwise the
/// fairing gets attached without replacing any existing fairing.
///
/// [singleton]: crate::fairing::Fairing#singletons
/// ///
/// # Example /// # Example
/// ///
@ -835,7 +838,7 @@ impl<P: Phase> Rocket<P> {
/// assert!(rocket.routes().any(|r| r.uri == "/hi/hello")); /// assert!(rocket.routes().any(|r| r.uri == "/hi/hello"));
/// ``` /// ```
pub fn routes(&self) -> impl Iterator<Item = &Route> { pub fn routes(&self) -> impl Iterator<Item = &Route> {
match self.0.as_state_ref() { match self.0.as_ref() {
StateRef::Build(p) => Either::Left(p.routes.iter()), StateRef::Build(p) => Either::Left(p.routes.iter()),
StateRef::Ignite(p) => Either::Right(p.router.routes()), StateRef::Ignite(p) => Either::Right(p.router.routes()),
StateRef::Orbit(p) => Either::Right(p.router.routes()), StateRef::Orbit(p) => Either::Right(p.router.routes()),
@ -866,7 +869,7 @@ impl<P: Phase> Rocket<P> {
/// assert!(rocket.catchers().any(|c| c.code == None && c.base() == "/")); /// assert!(rocket.catchers().any(|c| c.code == None && c.base() == "/"));
/// ``` /// ```
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> { pub fn catchers(&self) -> impl Iterator<Item = &Catcher> {
match self.0.as_state_ref() { match self.0.as_ref() {
StateRef::Build(p) => Either::Left(p.catchers.iter()), StateRef::Build(p) => Either::Left(p.catchers.iter()),
StateRef::Ignite(p) => Either::Right(p.router.catchers()), StateRef::Ignite(p) => Either::Right(p.router.catchers()),
StateRef::Orbit(p) => Either::Right(p.router.catchers()), StateRef::Orbit(p) => Either::Right(p.router.catchers()),
@ -886,13 +889,217 @@ impl<P: Phase> Rocket<P> {
/// assert_eq!(rocket.state::<MyState>().unwrap(), &MyState("hello!")); /// assert_eq!(rocket.state::<MyState>().unwrap(), &MyState("hello!"));
/// ``` /// ```
pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> { pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
match self.0.as_state_ref() { match self.0.as_ref() {
StateRef::Build(p) => p.state.try_get(), StateRef::Build(p) => p.state.try_get(),
StateRef::Ignite(p) => p.state.try_get(), StateRef::Ignite(p) => p.state.try_get(),
StateRef::Orbit(p) => p.state.try_get(), StateRef::Orbit(p) => p.state.try_get(),
} }
} }
/// Returns a reference to the first fairing of type `F` if it is attached.
/// Otherwise, returns `None`.
///
/// To retrieve a _mutable_ reference to fairing `F`, use
/// [`Rocket::fairing_mut()`] instead.
///
/// # Example
///
/// ```rust
/// # use rocket::{Rocket, Request, Data, Response, Build, Orbit};
/// # use rocket::fairing::{self, Fairing, Info, Kind};
/// #
/// # #[rocket::async_trait]
/// # impl Fairing for MyFairing {
/// # fn info(&self) -> Info {
/// # Info { name: "", kind: Kind::Ignite }
/// # }
/// # }
/// #
/// # #[rocket::async_trait]
/// # impl Fairing for MySingletonFairing {
/// # fn info(&self) -> Info {
/// # Info { name: "", kind: Kind::Ignite | Kind::Singleton }
/// # }
/// # }
/// // A regular, non-singleton fairing.
/// struct MyFairing(&'static str);
///
/// // A singleton fairing.
/// struct MySingletonFairing(&'static str);
///
/// // fairing is not attached, returns `None`
/// let rocket = rocket::build();
/// assert!(rocket.fairing::<MyFairing>().is_none());
/// assert!(rocket.fairing::<MySingletonFairing>().is_none());
///
/// // attach fairing, now returns `Some`
/// let rocket = rocket.attach(MyFairing("some state"));
/// assert!(rocket.fairing::<MyFairing>().is_some());
/// assert_eq!(rocket.fairing::<MyFairing>().unwrap().0, "some state");
///
/// // it returns the first fairing of a given type only
/// let rocket = rocket.attach(MyFairing("other state"));
/// assert_eq!(rocket.fairing::<MyFairing>().unwrap().0, "some state");
///
/// // attach fairing, now returns `Some`
/// let rocket = rocket.attach(MySingletonFairing("first"));
/// assert_eq!(rocket.fairing::<MySingletonFairing>().unwrap().0, "first");
///
/// // recall that new singletons replace existing attached singletons
/// let rocket = rocket.attach(MySingletonFairing("second"));
/// assert_eq!(rocket.fairing::<MySingletonFairing>().unwrap().0, "second");
/// ```
pub fn fairing<F: Fairing>(&self) -> Option<&F> {
match self.0.as_ref() {
StateRef::Build(p) => p.fairings.filter::<F>().next(),
StateRef::Ignite(p) => p.fairings.filter::<F>().next(),
StateRef::Orbit(p) => p.fairings.filter::<F>().next(),
}
}
/// Returns an iterator over all attached fairings of type `F`, if any.
///
/// # Example
///
/// ```rust
/// # use rocket::{Rocket, Request, Data, Response, Build, Orbit};
/// # use rocket::fairing::{self, Fairing, Info, Kind};
/// #
/// # #[rocket::async_trait]
/// # impl Fairing for MyFairing {
/// # fn info(&self) -> Info {
/// # Info { name: "", kind: Kind::Ignite }
/// # }
/// # }
/// #
/// # #[rocket::async_trait]
/// # impl Fairing for MySingletonFairing {
/// # fn info(&self) -> Info {
/// # Info { name: "", kind: Kind::Ignite | Kind::Singleton }
/// # }
/// # }
/// // A regular, non-singleton fairing.
/// struct MyFairing(&'static str);
///
/// // A singleton fairing.
/// struct MySingletonFairing(&'static str);
///
/// let rocket = rocket::build();
/// assert_eq!(rocket.fairings::<MyFairing>().count(), 0);
/// assert_eq!(rocket.fairings::<MySingletonFairing>().count(), 0);
///
/// let rocket = rocket.attach(MyFairing("some state"))
/// .attach(MySingletonFairing("first"))
/// .attach(MySingletonFairing("second"))
/// .attach(MyFairing("other state"))
/// .attach(MySingletonFairing("third"));
///
/// let my_fairings: Vec<_> = rocket.fairings::<MyFairing>().collect();
/// assert_eq!(my_fairings.len(), 2);
/// assert_eq!(my_fairings[0].0, "some state");
/// assert_eq!(my_fairings[1].0, "other state");
///
/// let my_singleton: Vec<_> = rocket.fairings::<MySingletonFairing>().collect();
/// assert_eq!(my_singleton.len(), 1);
/// assert_eq!(my_singleton[0].0, "third");
/// ```
pub fn fairings<F: Fairing>(&self) -> impl Iterator<Item = &F> {
match self.0.as_ref() {
StateRef::Build(p) => Either::Left(p.fairings.filter::<F>()),
StateRef::Ignite(p) => Either::Right(p.fairings.filter::<F>()),
StateRef::Orbit(p) => Either::Right(p.fairings.filter::<F>()),
}
}
/// Returns a mutable reference to the first fairing of type `F` if it is
/// attached. Otherwise, returns `None`.
///
/// # Example
///
/// ```rust
/// # use rocket::{Rocket, Request, Data, Response, Build, Orbit};
/// # use rocket::fairing::{self, Fairing, Info, Kind};
/// #
/// # #[rocket::async_trait]
/// # impl Fairing for MyFairing {
/// # fn info(&self) -> Info {
/// # Info { name: "", kind: Kind::Ignite }
/// # }
/// # }
/// // A regular, non-singleton fairing.
/// struct MyFairing(&'static str);
///
/// // fairing is not attached, returns `None`
/// let mut rocket = rocket::build();
/// assert!(rocket.fairing_mut::<MyFairing>().is_none());
///
/// // attach fairing, now returns `Some`
/// let mut rocket = rocket.attach(MyFairing("some state"));
/// assert!(rocket.fairing_mut::<MyFairing>().is_some());
/// assert_eq!(rocket.fairing_mut::<MyFairing>().unwrap().0, "some state");
///
/// // we can modify the fairing
/// rocket.fairing_mut::<MyFairing>().unwrap().0 = "other state";
/// assert_eq!(rocket.fairing_mut::<MyFairing>().unwrap().0, "other state");
///
/// // it returns the first fairing of a given type only
/// let mut rocket = rocket.attach(MyFairing("yet more state"));
/// assert_eq!(rocket.fairing_mut::<MyFairing>().unwrap().0, "other state");
/// ```
pub fn fairing_mut<F: Fairing>(&mut self) -> Option<&mut F> {
match self.0.as_mut() {
StateRefMut::Build(p) => p.fairings.filter_mut::<F>().next(),
StateRefMut::Ignite(p) => p.fairings.filter_mut::<F>().next(),
StateRefMut::Orbit(p) => p.fairings.filter_mut::<F>().next(),
}
}
/// Returns an iterator of mutable references to all attached fairings of
/// type `F`, if any.
///
/// # Example
///
/// ```rust
/// # use rocket::{Rocket, Request, Data, Response, Build, Orbit};
/// # use rocket::fairing::{self, Fairing, Info, Kind};
/// #
/// # #[rocket::async_trait]
/// # impl Fairing for MyFairing {
/// # fn info(&self) -> Info {
/// # Info { name: "", kind: Kind::Ignite }
/// # }
/// # }
/// // A regular, non-singleton fairing.
/// struct MyFairing(&'static str);
///
/// let mut rocket = rocket::build()
/// .attach(MyFairing("some state"))
/// .attach(MyFairing("other state"))
/// .attach(MyFairing("yet more state"));
///
/// let mut fairings: Vec<_> = rocket.fairings_mut::<MyFairing>().collect();
/// assert_eq!(fairings.len(), 3);
/// assert_eq!(fairings[0].0, "some state");
/// assert_eq!(fairings[1].0, "other state");
/// assert_eq!(fairings[2].0, "yet more state");
///
/// // we can modify the fairings
/// fairings[1].0 = "modified state";
///
/// let fairings: Vec<_> = rocket.fairings::<MyFairing>().collect();
/// assert_eq!(fairings.len(), 3);
/// assert_eq!(fairings[0].0, "some state");
/// assert_eq!(fairings[1].0, "modified state");
/// assert_eq!(fairings[2].0, "yet more state");
/// ```
pub fn fairings_mut<F: Fairing>(&mut self) -> impl Iterator<Item = &mut F> {
match self.0.as_mut() {
StateRefMut::Build(p) => Either::Left(p.fairings.filter_mut::<F>()),
StateRefMut::Ignite(p) => Either::Right(p.fairings.filter_mut::<F>()),
StateRefMut::Orbit(p) => Either::Right(p.fairings.filter_mut::<F>()),
}
}
/// Returns the figment derived from the configuration provider set for /// Returns the figment derived from the configuration provider set for
/// `self`. To extract a typed config, prefer to use /// `self`. To extract a typed config, prefer to use
/// [`AdHoc::config()`](crate::fairing::AdHoc::config()). /// [`AdHoc::config()`](crate::fairing::AdHoc::config()).
@ -910,7 +1117,7 @@ impl<P: Phase> Rocket<P> {
/// let figment = rocket.figment(); /// let figment = rocket.figment();
/// ``` /// ```
pub fn figment(&self) -> &Figment { pub fn figment(&self) -> &Figment {
match self.0.as_state_ref() { match self.0.as_ref() {
StateRef::Build(p) => &p.figment, StateRef::Build(p) => &p.figment,
StateRef::Ignite(p) => &p.figment, StateRef::Ignite(p) => &p.figment,
StateRef::Orbit(p) => &p.figment, StateRef::Orbit(p) => &p.figment,