From ab3413826cee3a7f3f678a4dd7c6f63407878a38 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Tue, 9 Mar 2021 21:13:26 -0800 Subject: [PATCH] Improve fairing example. --- examples/fairings/src/main.rs | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/examples/fairings/src/main.rs b/examples/fairings/src/main.rs index 849eb4f6..c143c55c 100644 --- a/examples/fairings/src/main.rs +++ b/examples/fairings/src/main.rs @@ -2,8 +2,9 @@ use std::io::Cursor; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; -use rocket::{Request, State, Data, Response}; +use rocket::{Rocket, Request, State, Data, Response}; use rocket::fairing::{AdHoc, Fairing, Info, Kind}; use rocket::http::{Method, ContentType, Status}; @@ -11,10 +12,10 @@ struct Token(i64); #[cfg(test)] mod tests; -#[derive(Default)] +#[derive(Default, Clone)] struct Counter { - get: AtomicUsize, - post: AtomicUsize, + get: Arc, + post: Arc, } #[rocket::async_trait] @@ -22,7 +23,7 @@ impl Fairing for Counter { fn info(&self) -> Info { Info { name: "GET/POST Counter", - kind: Kind::Request | Kind::Response + kind: Kind::Attach | Kind::Request } } @@ -34,20 +35,15 @@ impl Fairing for Counter { } } - async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { - if res.status() != Status::NotFound { - return + async fn on_attach(&self, rocket: Rocket) -> Result { + #[get("/counts")] + fn counts(counts: State<'_, Counter>) -> String { + let get_count = counts.get.load(Ordering::Relaxed); + let post_count = counts.post.load(Ordering::Relaxed); + format!("Get: {}\nPost: {}", get_count, post_count) } - if req.method() == Method::Get && req.uri().path() == "/counts" { - let get_count = self.get.load(Ordering::Relaxed); - let post_count = self.post.load(Ordering::Relaxed); - - let body = format!("Get: {}\nPost: {}", get_count, post_count); - res.set_status(Status::Ok); - res.set_header(ContentType::Plain); - res.set_sized_body(body.len(), Cursor::new(body)); - } + Ok(rocket.manage(self.clone()).mount("/", routes![counts])) } }