diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 5f3a6961..67a03ff5 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -19,9 +19,10 @@ edition = "2018" all-features = true [features] -default = ["private-cookies"] +default = ["private-cookies", "ctrl_c_shutdown"] tls = ["rocket_http/tls"] private-cookies = ["rocket_http/private-cookies"] +ctrl_c_shutdown = ["tokio/signal"] [dependencies] rocket_codegen = { version = "0.5.0-dev", path = "../codegen" } diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index c3ed1721..12bd961e 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -8,6 +8,13 @@ use yansi::Paint; use crate::http::hyper; use crate::router::Route; +// TODO.async docs +#[derive(Debug)] +pub enum Error { + Launch(LaunchError), + Run(hyper::Error), +} + /// The kind of launch error that occurred. /// /// In almost every instance, a launch error occurs because of an I/O error; @@ -44,7 +51,9 @@ pub enum LaunchErrorKind { /// as inspected; a subsequent `drop` of the value will _not_ result in a panic. /// The following snippet illustrates this: /// -/// ```rust +// TODO.async This isn't true any more, as `.launch()` now returns a +// `Result<(), crate::error::Error>`, which could also be a runtime error. +/// ```rust,ignore /// # if false { /// let error = rocket::ignite().launch(); /// @@ -106,11 +115,14 @@ impl LaunchError { /// # Example /// /// ```rust + /// use rocket::error::Error; /// # if false { - /// let error = rocket::ignite().launch(); - /// - /// // This line is only reached if launch failed. - /// let error_kind = error.kind(); + /// if let Err(error) = rocket::ignite().launch() { + /// match error { + /// Error::Launch(err) => println!("Found a launch error: {}", err.kind()), + /// Error::Run(err) => println!("Error at runtime"), + /// } + /// } /// # } /// ``` #[inline] diff --git a/core/lib/src/lib.rs b/core/lib/src/lib.rs index c3cf12ad..f343856b 100644 --- a/core/lib/src/lib.rs +++ b/core/lib/src/lib.rs @@ -108,6 +108,7 @@ pub mod data; pub mod handler; pub mod fairing; pub mod error; +pub mod shutdown; // Reexport of HTTP everything. pub mod http { diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index 793c03b6..9275b42b 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -7,7 +7,8 @@ use std::net::ToSocketAddrs; use std::sync::Arc; use std::time::Duration; -use futures::future::{Future, FutureExt, TryFutureExt, BoxFuture}; +use futures::future::{Future, FutureExt, BoxFuture}; +use futures::channel::{mpsc, oneshot}; use futures::stream::StreamExt; use futures::task::SpawnExt; use futures_tokio_compat::Compat as TokioCompat; @@ -29,6 +30,7 @@ use crate::error::{LaunchError, LaunchErrorKind}; use crate::fairing::{Fairing, Fairings}; use crate::logger::PaintExt; use crate::ext::AsyncReadExt; +use crate::shutdown::{ShutdownHandle, ShutdownHandleManaged}; use crate::http::{Method, Status, Header}; use crate::http::hyper::{self, header}; @@ -43,6 +45,8 @@ pub struct Rocket { catchers: HashMap, pub(crate) state: Container, fairings: Fairings, + shutdown_handle: ShutdownHandle, + shutdown_receiver: Option>, } // This function tries to hide all of the Hyper-ness from Rocket. It @@ -464,14 +468,22 @@ impl Rocket { Paint::default(LoggedValue(value)).bold()); } - Rocket { + let (shutdown_sender, shutdown_receiver) = mpsc::channel(1); + + let rocket = Rocket { config, router: Router::new(), default_catchers: catcher::defaults::get(), catchers: catcher::defaults::get(), state: Container::new(), fairings: Fairings::new(), - } + shutdown_handle: ShutdownHandle(shutdown_sender), + shutdown_receiver: Some(shutdown_receiver), + }; + + rocket.state.set(ShutdownHandleManaged(rocket.shutdown_handle.clone())); + + rocket } /// Mounts all of the routes in the supplied vector at the given `base` @@ -721,11 +733,10 @@ impl Rocket { /// }); /// # } /// ``` - // TODO.async Decide on an return type, possibly creating a discriminated union. pub fn spawn_on( mut self, runtime: &tokio::runtime::Runtime, - ) -> Result>>, LaunchError> { + ) -> Result>, LaunchError> { #[cfg(feature = "tls")] use crate::http::tls; self = self.prelaunch_check()?; @@ -771,6 +782,11 @@ impl Rocket { // Restore the log level back to what it originally was. logger::pop_max_level(); + // We need to get these values before moving `self` into an `Arc`. + let mut shutdown_receiver = self.shutdown_receiver + .take().expect("shutdown receiver has already been used"); + let shutdown_handle = self.get_shutdown_handle(); + let rocket = Arc::new(self); let spawn = Box::new(TokioCompat::new(runtime.executor())); let service = hyper::make_service_fn(move |socket: &hyper::AddrStream| { @@ -784,19 +800,54 @@ impl Rocket { } }); - // NB: executor must be passed manually here, see hyperium/hyper#1537 - let server = hyper::Server::builder(incoming) - .executor(runtime.executor()) - .serve(service); + #[cfg(feature = "ctrl_c_shutdown")] + let (cancel_ctrl_c_listener_sender, cancel_ctrl_c_listener_receiver) = oneshot::channel(); + + // NB: executor must be passed manually here, see hyperium/hyper#1537 + let (future, handle) = hyper::Server::builder(incoming) + .executor(runtime.executor()) + .serve(service) + .with_graceful_shutdown(async move { shutdown_receiver.next().await; }) + .inspect(|_| { + #[cfg(feature = "ctrl_c_shutdown")] + let _ = cancel_ctrl_c_listener_sender.send(()); + }) + .remote_handle(); - let (future, handle) = server.remote_handle(); runtime.spawn(future); - Ok(handle.err_into()) + + #[cfg(feature = "ctrl_c_shutdown")] + match tokio::net::signal::ctrl_c() { + Ok(mut ctrl_c) => { + runtime.spawn(async move { + // Stop listening for `ctrl_c` if the server shuts down + // a different way to avoid waiting forever. + futures::future::select( + ctrl_c.next(), + cancel_ctrl_c_listener_receiver, + ).await; + + // Request the server shutdown. + shutdown_handle.shutdown(); + }); + }, + Err(err) => { + // Signal handling isn't strictly necessary, so we can skip it + // if necessary. It's a good idea to let the user know we're + // doing so in case they are expecting certain behavior. + let message = "Not listening for shutdown keybinding."; + warn!("{}", Paint::yellow(message)); + info_!("Error: {}", err); + }, + } + + Ok(handle) } /// Starts the application server and begins listening for and dispatching - /// requests to mounted routes and catchers. Unless there is an error, this - /// function does not return and blocks until program termination. + /// requests to mounted routes and catchers. This function does not return + /// unless a shutdown is requested via a [`ShutdownHandle`] or there is an + /// error. /// /// # Error /// @@ -812,8 +863,9 @@ impl Rocket { /// rocket::ignite().launch(); /// # } /// ``` - // TODO.async Decide on an return type, possibly creating a discriminated union. - pub fn launch(self) -> Box { + pub fn launch(self) -> Result<(), crate::error::Error> { + use crate::error::Error; + // TODO.async What meaning should config.workers have now? // Initialize the tokio runtime let runtime = tokio::runtime::Builder::new() @@ -821,16 +873,43 @@ impl Rocket { .build() .expect("Cannot build runtime!"); - // TODO.async: Use with_graceful_shutdown, and let launch() return a Result<(), Error> match self.spawn_on(&runtime) { - Ok(fut) => match runtime.block_on(fut) { - Ok(_) => unreachable!("the call to `block_on` should block on success"), - Err(err) => err, - } - Err(err) => Box::new(err), + Ok(fut) => runtime.block_on(fut).map_err(Error::Run), + Err(err) => Err(Error::Launch(err)), } } + /// Returns a [`ShutdownHandle`], which can be used to gracefully terminate + /// the instance of Rocket. In routes, you should use the [`ShutdownHandle`] + /// request guard. + /// + /// # Example + /// + /// ```rust + /// # #![feature(proc_macro_hygiene)] + /// # use std::{thread, time::Duration}; + /// # + /// let rocket = rocket::ignite(); + /// let handle = rocket.get_shutdown_handle(); + /// # let real_handle = rocket.get_shutdown_handle(); + /// + /// # if false { + /// thread::spawn(move || { + /// thread::sleep(Duration::from_secs(10)); + /// handle.shutdown(); + /// }); + /// # } + /// # real_handle.shutdown(); + /// + /// // Shuts down after 10 seconds + /// let shutdown_result = rocket.launch(); + /// assert!(shutdown_result.is_ok()); + /// ``` + #[inline(always)] + pub fn get_shutdown_handle(&self) -> ShutdownHandle { + self.shutdown_handle.clone() + } + /// Returns an iterator over all of the routes mounted on this instance of /// Rocket. /// diff --git a/core/lib/src/shutdown.rs b/core/lib/src/shutdown.rs new file mode 100644 index 00000000..0a3dc440 --- /dev/null +++ b/core/lib/src/shutdown.rs @@ -0,0 +1,50 @@ +use crate::request::{FromRequest, Outcome, Request}; +use futures::channel::mpsc; + +/// # Example +/// +/// ```rust +/// # #![feature(proc_macro_hygiene)] +/// # #[macro_use] extern crate rocket; +/// # +/// use rocket::shutdown::ShutdownHandle; +/// +/// #[get("/shutdown")] +/// fn shutdown(handle: ShutdownHandle) -> &'static str { +/// handle.shutdown(); +/// "Shutting down..." +/// } +/// +/// fn main() { +/// # if false { +/// rocket::ignite() +/// .mount("/", routes![shutdown]) +/// .launch() +/// .expect("server failed unexpectedly"); +/// # } +/// } +/// ``` +#[derive(Debug, Clone)] +pub struct ShutdownHandle(pub(crate) mpsc::Sender<()>); + +impl ShutdownHandle { + /// Notify Rocket to shut down gracefully. + #[inline] + pub fn shutdown(mut self) { + // Intentionally ignore any error, as the only scenarios this can happen + // is sending too many shutdown requests or we're already shut down. + let _ = self.0.try_send(()); + } +} + +impl FromRequest<'_, '_> for ShutdownHandle { + type Error = std::convert::Infallible; + + #[inline] + fn from_request(request: &Request<'_>) -> Outcome { + Outcome::Success(request.state.managed.get::().0.clone()) + } +} + +// Use this type in managed state to avoid placing `ShutdownHandle` in it. +pub(crate) struct ShutdownHandleManaged(pub ShutdownHandle); diff --git a/examples/config/src/main.rs b/examples/config/src/main.rs index 114a6b33..b96db851 100644 --- a/examples/config/src/main.rs +++ b/examples/config/src/main.rs @@ -1,4 +1,4 @@ // This example's illustration is the Rocket.toml file. fn main() { - rocket::ignite().launch(); + let _ = rocket::ignite().launch(); } diff --git a/examples/content_types/src/main.rs b/examples/content_types/src/main.rs index 7c012231..d31648be 100644 --- a/examples/content_types/src/main.rs +++ b/examples/content_types/src/main.rs @@ -65,7 +65,7 @@ fn not_found(request: &Request<'_>) -> Html { } fn main() { - rocket::ignite() + let _ = rocket::ignite() .mount("/hello", routes![get_hello, post_hello]) .register(catchers![not_found]) .launch(); diff --git a/examples/cookies/src/main.rs b/examples/cookies/src/main.rs index 1232def3..2d4e84c2 100644 --- a/examples/cookies/src/main.rs +++ b/examples/cookies/src/main.rs @@ -39,5 +39,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/errors/src/main.rs b/examples/errors/src/main.rs index 3aa27066..8cad7a21 100644 --- a/examples/errors/src/main.rs +++ b/examples/errors/src/main.rs @@ -26,5 +26,6 @@ fn main() { .launch(); println!("Whoops! Rocket didn't launch!"); - println!("This went wrong: {}", e); + // TODO.async Uncomment the following line once `.launch()`'s error type is determined. + // println!("This went wrong: {}", e); } diff --git a/examples/fairings/src/main.rs b/examples/fairings/src/main.rs index 5a3efbbc..e5cc0f37 100644 --- a/examples/fairings/src/main.rs +++ b/examples/fairings/src/main.rs @@ -96,5 +96,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/form_kitchen_sink/src/main.rs b/examples/form_kitchen_sink/src/main.rs index af494eaf..3ba791cb 100644 --- a/examples/form_kitchen_sink/src/main.rs +++ b/examples/form_kitchen_sink/src/main.rs @@ -46,5 +46,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/form_validation/src/main.rs b/examples/form_validation/src/main.rs index 0495409b..97e4dd36 100644 --- a/examples/form_validation/src/main.rs +++ b/examples/form_validation/src/main.rs @@ -82,5 +82,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/handlebars_templates/src/main.rs b/examples/handlebars_templates/src/main.rs index e2b29535..46bd9d87 100644 --- a/examples/handlebars_templates/src/main.rs +++ b/examples/handlebars_templates/src/main.rs @@ -78,5 +78,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/hello_2018/src/main.rs b/examples/hello_2018/src/main.rs index d8d7c473..1d4e967f 100644 --- a/examples/hello_2018/src/main.rs +++ b/examples/hello_2018/src/main.rs @@ -10,5 +10,5 @@ fn hello() -> &'static str { } fn main() { - rocket::ignite().mount("/", routes![hello]).launch(); + let _ = rocket::ignite().mount("/", routes![hello]).launch(); } diff --git a/examples/hello_person/src/main.rs b/examples/hello_person/src/main.rs index 96c4ae2e..98fffdbf 100644 --- a/examples/hello_person/src/main.rs +++ b/examples/hello_person/src/main.rs @@ -15,5 +15,5 @@ fn hi(name: String) -> String { } fn main() { - rocket::ignite().mount("/", routes![hello, hi]).launch(); + let _ = rocket::ignite().mount("/", routes![hello, hi]).launch(); } diff --git a/examples/hello_world/src/main.rs b/examples/hello_world/src/main.rs index 6c1111f4..1d2876fd 100644 --- a/examples/hello_world/src/main.rs +++ b/examples/hello_world/src/main.rs @@ -10,5 +10,5 @@ fn hello() -> &'static str { } fn main() { - rocket::ignite().mount("/", routes![hello]).launch(); + let _ = rocket::ignite().mount("/", routes![hello]).launch(); } diff --git a/examples/json/src/main.rs b/examples/json/src/main.rs index 65698c94..13ef804f 100644 --- a/examples/json/src/main.rs +++ b/examples/json/src/main.rs @@ -77,5 +77,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/managed_queue/src/main.rs b/examples/managed_queue/src/main.rs index d0588bc6..ba58e1a3 100644 --- a/examples/managed_queue/src/main.rs +++ b/examples/managed_queue/src/main.rs @@ -26,5 +26,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/manual_routes/src/main.rs b/examples/manual_routes/src/main.rs index 3ae0b02e..c5011599 100644 --- a/examples/manual_routes/src/main.rs +++ b/examples/manual_routes/src/main.rs @@ -118,5 +118,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/msgpack/src/main.rs b/examples/msgpack/src/main.rs index 2af8c51d..14a0a1b1 100644 --- a/examples/msgpack/src/main.rs +++ b/examples/msgpack/src/main.rs @@ -28,5 +28,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/optional_redirect/src/main.rs b/examples/optional_redirect/src/main.rs index 8c81c24a..3287db12 100644 --- a/examples/optional_redirect/src/main.rs +++ b/examples/optional_redirect/src/main.rs @@ -27,5 +27,5 @@ fn login() -> &'static str { } fn main() { - rocket::ignite().mount("/", routes![root, user, login]).launch(); + let _ = rocket::ignite().mount("/", routes![root, user, login]).launch(); } diff --git a/examples/pastebin/src/main.rs b/examples/pastebin/src/main.rs index 3be264f4..70afb0fd 100644 --- a/examples/pastebin/src/main.rs +++ b/examples/pastebin/src/main.rs @@ -56,5 +56,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/query_params/src/main.rs b/examples/query_params/src/main.rs index c2ba93af..4a8dae9c 100644 --- a/examples/query_params/src/main.rs +++ b/examples/query_params/src/main.rs @@ -37,5 +37,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/ranking/src/main.rs b/examples/ranking/src/main.rs index 2267fc3d..d92eea57 100644 --- a/examples/ranking/src/main.rs +++ b/examples/ranking/src/main.rs @@ -17,5 +17,5 @@ fn hi(name: String, age: &RawStr) -> String { } fn main() { - rocket::ignite().mount("/", routes![hi, hello]).launch(); + let _ = rocket::ignite().mount("/", routes![hi, hello]).launch(); } diff --git a/examples/raw_sqlite/src/main.rs b/examples/raw_sqlite/src/main.rs index 4daad628..242f6c11 100644 --- a/examples/raw_sqlite/src/main.rs +++ b/examples/raw_sqlite/src/main.rs @@ -49,5 +49,5 @@ fn rocket() -> Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/raw_upload/src/main.rs b/examples/raw_upload/src/main.rs index 22b25462..45ed8968 100644 --- a/examples/raw_upload/src/main.rs +++ b/examples/raw_upload/src/main.rs @@ -22,5 +22,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/redirect/src/main.rs b/examples/redirect/src/main.rs index cfd7db27..31b00e6a 100644 --- a/examples/redirect/src/main.rs +++ b/examples/redirect/src/main.rs @@ -17,5 +17,5 @@ fn login() -> &'static str { } fn main() { - rocket::ignite().mount("/", routes![root, login]).launch(); + let _ = rocket::ignite().mount("/", routes![root, login]).launch(); } diff --git a/examples/request_guard/src/main.rs b/examples/request_guard/src/main.rs index 73ca1b4a..c41c54d4 100644 --- a/examples/request_guard/src/main.rs +++ b/examples/request_guard/src/main.rs @@ -26,7 +26,7 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } #[cfg(test)] diff --git a/examples/request_local_state/src/main.rs b/examples/request_local_state/src/main.rs index f2a3ccec..93ee8bd2 100644 --- a/examples/request_local_state/src/main.rs +++ b/examples/request_local_state/src/main.rs @@ -51,5 +51,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/session/src/main.rs b/examples/session/src/main.rs index c4de5208..eb1c89e5 100644 --- a/examples/session/src/main.rs +++ b/examples/session/src/main.rs @@ -86,5 +86,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/state/src/main.rs b/examples/state/src/main.rs index 9bd48352..3f08df30 100644 --- a/examples/state/src/main.rs +++ b/examples/state/src/main.rs @@ -31,5 +31,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/static_files/src/main.rs b/examples/static_files/src/main.rs index ce4a1776..921d6f95 100644 --- a/examples/static_files/src/main.rs +++ b/examples/static_files/src/main.rs @@ -10,5 +10,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/stream/src/main.rs b/examples/stream/src/main.rs index 3596d968..eba6bb82 100644 --- a/examples/stream/src/main.rs +++ b/examples/stream/src/main.rs @@ -32,5 +32,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/tera_templates/src/main.rs b/examples/tera_templates/src/main.rs index 83c53122..cdd8490c 100644 --- a/examples/tera_templates/src/main.rs +++ b/examples/tera_templates/src/main.rs @@ -43,5 +43,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/testing/src/main.rs b/examples/testing/src/main.rs index 1e263082..c5f31cf6 100644 --- a/examples/testing/src/main.rs +++ b/examples/testing/src/main.rs @@ -12,7 +12,7 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } #[cfg(test)] diff --git a/examples/tls/src/main.rs b/examples/tls/src/main.rs index 6c1111f4..1d2876fd 100644 --- a/examples/tls/src/main.rs +++ b/examples/tls/src/main.rs @@ -10,5 +10,5 @@ fn hello() -> &'static str { } fn main() { - rocket::ignite().mount("/", routes![hello]).launch(); + let _ = rocket::ignite().mount("/", routes![hello]).launch(); } diff --git a/examples/todo/src/main.rs b/examples/todo/src/main.rs index 318a13cd..12ee7da9 100644 --- a/examples/todo/src/main.rs +++ b/examples/todo/src/main.rs @@ -115,5 +115,5 @@ fn rocket() -> Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); } diff --git a/examples/uuid/src/main.rs b/examples/uuid/src/main.rs index 38639da9..04dbe0c1 100644 --- a/examples/uuid/src/main.rs +++ b/examples/uuid/src/main.rs @@ -38,5 +38,5 @@ fn rocket() -> rocket::Rocket { } fn main() { - rocket().launch(); + let _ = rocket().launch(); }