Replace use of 'hyper::AddrIncoming' with a Listener API and implement TLS.

Types can now implement the new 'Listener' trait, which means they can
report the address they are listening on and asynchronously accept
connections. 'Connection's are read/write streams that can additionally
report the remote address.

Listener is implemented for 'tokio_net::tcp::TcpListener' and for
the new 'rocket_http::tls::TlsListener' based on 'tokio-rustls'.

The new private function 'Rocket::listen_on()' now does the main setup
for launch and is generic over a Listener. In the future, a more refined
version of the API can be exposed so that applications can implement
their own listeners.
This commit is contained in:
Jeb Rosen 2019-09-14 08:10:00 -07:00 committed by Sergio Benitez
parent 62a99e9e49
commit 523c6099fb
9 changed files with 464 additions and 113 deletions

View File

@ -22,16 +22,21 @@ private-cookies = ["cookie/private", "cookie/key-expansion"]
[dependencies] [dependencies]
smallvec = "1.0" smallvec = "1.0"
percent-encoding = "1" percent-encoding = "1"
hyper = { version = "=0.13.0-alpha.1", default-features = false, features = ["runtime"] } hyper = { version = "=0.13.0-alpha.1", default-features = false }
http = "0.1.17" http = "0.1.17"
mime = "0.3.13" mime = "0.3.13"
time = "0.2.11" time = "0.2.11"
indexmap = "1.0" indexmap = "1.0"
state = "0.4" state = "0.4"
tokio-rustls = { version = "0.10.3", optional = true } tokio-rustls = { version = "0.12.0-alpha.2", optional = true }
tokio-io = "=0.2.0-alpha.4"
tokio-net = "=0.2.0-alpha.4"
tokio-timer = "=0.3.0-alpha.4"
cookie = { version = "0.14.0", features = ["percent-encode"] } cookie = { version = "0.14.0", features = ["percent-encode"] }
pear = "0.1" pear = "0.1"
unicode-xid = "0.2" unicode-xid = "0.2"
futures-preview = "0.3.0-alpha.18"
log = "0.4"
[dev-dependencies] [dev-dependencies]
rocket = { version = "0.5.0-dev", path = "../lib" } rocket = { version = "0.5.0-dev", path = "../lib" }

View File

@ -8,7 +8,6 @@
#[doc(hidden)] pub use hyper::body::{Payload, Sender as BodySender}; #[doc(hidden)] pub use hyper::body::{Payload, Sender as BodySender};
#[doc(hidden)] pub use hyper::error::Error; #[doc(hidden)] pub use hyper::error::Error;
#[doc(hidden)] pub use hyper::service::{make_service_fn, service_fn, MakeService, Service}; #[doc(hidden)] pub use hyper::service::{make_service_fn, service_fn, MakeService, Service};
#[doc(hidden)] pub use hyper::server::conn::{AddrIncoming, AddrStream};
#[doc(hidden)] pub use hyper::Chunk; #[doc(hidden)] pub use hyper::Chunk;
#[doc(hidden)] pub use http::header::HeaderMap; #[doc(hidden)] pub use http::header::HeaderMap;

View File

@ -37,6 +37,7 @@ mod header;
mod accept; mod accept;
mod raw_str; mod raw_str;
mod parse; mod parse;
mod listener;
pub mod uncased; pub mod uncased;
@ -51,6 +52,7 @@ pub mod private {
// This one we need to expose for core. // This one we need to expose for core.
pub use crate::cookies::{Key, CookieJar}; pub use crate::cookies::{Key, CookieJar};
pub use crate::listener::{Incoming, Listener, Connection, bind_tcp};
} }
pub use crate::method::Method; pub use crate::method::Method;

187
core/http/src/listener.rs Normal file
View File

@ -0,0 +1,187 @@
use std::fmt;
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use futures::ready;
use futures::stream::Stream;
use log::{debug, error};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_timer::Delay;
use tokio_net::tcp::{TcpListener, TcpStream};
// TODO.async: 'Listener' and 'Connection' provide common enough functionality
// that they could be introduced in upstream libraries.
/// A 'Listener' yields incoming connections
pub trait Listener {
type Connection: Connection;
/// Return the actual address this listener bound to.
fn local_addr(&self) -> Option<SocketAddr>;
/// Try to accept an incoming Connection if ready
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<Self::Connection, io::Error>>;
}
/// A 'Connection' represents an open connection to a client
pub trait Connection: AsyncRead + AsyncWrite {
fn remote_addr(&self) -> Option<SocketAddr>;
}
/// This is a genericized version of hyper's AddrIncoming that is intended to be
/// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix
/// sockets. It does this by briding the `Listener` trait to what hyper wants (a
/// Stream of AsyncRead+AsyncWrite). This type is internal to Rocket.
#[must_use = "streams do nothing unless polled"]
pub struct Incoming<L> {
listener: L,
sleep_on_errors: Option<Duration>,
pending_error_delay: Option<Delay>,
}
impl<L: Listener> Incoming<L> {
/// Construct an `Incoming` from an existing `Listener`.
pub fn from_listener(listener: L) -> Self {
Self {
listener,
sleep_on_errors: Some(Duration::from_secs(1)),
pending_error_delay: None,
}
}
/// Set whether to sleep on accept errors.
///
/// A possible scenario is that the process has hit the max open files
/// allowed, and so trying to accept a new connection will fail with
/// `EMFILE`. In some cases, it's preferable to just wait for some time, if
/// the application will likely close some files (or connections), and try
/// to accept the connection again. If this option is `true`, the error
/// will be logged at the `error` level, since it is still a big deal,
/// and then the listener will sleep for 1 second.
///
/// In other cases, hitting the max open files should be treat similarly
/// to being out-of-memory, and simply error (and shutdown). Setting
/// this option to `None` will allow that.
///
/// Default is 1 second.
pub fn set_sleep_on_errors(&mut self, val: Option<Duration>) {
self.sleep_on_errors = val;
}
fn poll_next_(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<L::Connection>> {
// Check if a previous delay is active that was set by IO errors.
if let Some(ref mut delay) = self.pending_error_delay {
match Pin::new(delay).poll(cx) {
Poll::Ready(()) => {}
Poll::Pending => return Poll::Pending,
}
}
self.pending_error_delay = None;
loop {
match self.listener.poll_accept(cx) {
Poll::Ready(Ok(stream)) => {
return Poll::Ready(Ok(stream));
},
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => {
// Connection errors can be ignored directly, continue by
// accepting the next request.
if is_connection_error(&e) {
debug!("accepted connection already errored: {}", e);
continue;
}
if let Some(duration) = self.sleep_on_errors {
error!("accept error: {}", e);
// Sleep for the specified duration
let delay = Instant::now() + duration;
// TODO.async: This depends on a tokio Timer being set in the environment
let mut error_delay = tokio_timer::delay(delay);
match Pin::new(&mut error_delay).poll(cx) {
Poll::Ready(()) => {
// Wow, it's been a second already? Ok then...
continue
},
Poll::Pending => {
self.pending_error_delay = Some(error_delay);
return Poll::Pending;
},
}
} else {
return Poll::Ready(Err(e));
}
},
}
}
}
}
impl<L: Listener + Unpin> Stream for Incoming<L> {
type Item = io::Result<L::Connection>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let result = ready!(self.poll_next_(cx));
Poll::Ready(Some(result))
}
}
/// This function defines errors that are per-connection. Which basically
/// means that if we get this error from `accept()` system call it means
/// next connection might be ready to be accepted.
///
/// All other errors will incur a delay before next `accept()` is performed.
/// The delay is useful to handle resource exhaustion errors like ENFILE
/// and EMFILE. Otherwise, could enter into tight loop.
fn is_connection_error(e: &io::Error) -> bool {
match e.kind() {
io::ErrorKind::ConnectionRefused |
io::ErrorKind::ConnectionAborted |
io::ErrorKind::ConnectionReset => true,
_ => false,
}
}
impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Incoming")
.field("listener", &self.listener)
.finish()
}
}
// TODO.async: Put these under a feature such as #[cfg(feature = "tokio-runtime")]
pub fn bind_tcp(address: SocketAddr) -> Pin<Box<dyn Future<Output=Result<TcpListener, io::Error>> + Send>> {
Box::pin(async move {
Ok(TcpListener::bind(address).await?)
})
}
impl Listener for TcpListener {
type Connection = TcpStream;
fn local_addr(&self) -> Option<SocketAddr> {
self.local_addr().ok()
}
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<Self::Connection, io::Error>> {
// NB: This is only okay because TcpListener::accept() is stateless.
let accept = self.accept();
futures::pin_mut!(accept);
accept.poll(cx).map_ok(|(stream, _addr)| stream)
}
}
impl Connection for TcpStream {
fn remote_addr(&self) -> Option<SocketAddr> {
self.peer_addr().ok()
}
}

View File

@ -1,8 +1,146 @@
pub use tokio_rustls::TlsAcceptor; use std::fs;
pub use tokio_rustls::rustls; use std::future::Future;
use std::io::{self, BufReader};
use std::net::SocketAddr;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio_net::tcp::{TcpListener, TcpStream};
use tokio_rustls::{TlsAcceptor, server::TlsStream};
use tokio_rustls::rustls;
pub use rustls::internal::pemfile; pub use rustls::internal::pemfile;
pub use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig}; pub use rustls::{Certificate, PrivateKey, ServerConfig};
// TODO.async: extract from hyper-sync-rustls some convenience use crate::listener::{Connection, Listener};
// functions to load certs and keys
#[derive(Debug)]
pub enum Error {
Io(io::Error),
BadCerts,
BadKeyCount,
BadKey,
}
// TODO.async: consider using async fs operations
pub fn load_certs<P: AsRef<Path>>(path: P) -> Result<Vec<rustls::Certificate>, Error> {
let certfile = fs::File::open(path.as_ref()).map_err(|e| Error::Io(e))?;
let mut reader = BufReader::new(certfile);
pemfile::certs(&mut reader).map_err(|_| Error::BadCerts)
}
pub fn load_private_key<P: AsRef<Path>>(path: P) -> Result<rustls::PrivateKey, Error> {
use std::io::Seek;
use std::io::BufRead;
let keyfile = fs::File::open(path.as_ref()).map_err(Error::Io)?;
let mut reader = BufReader::new(keyfile);
// "rsa" (PKCS1) PEM files have a different first-line header than PKCS8
// PEM files, use that to determine the parse function to use.
let mut first_line = String::new();
reader.read_line(&mut first_line).map_err(Error::Io)?;
reader.seek(io::SeekFrom::Start(0)).map_err(Error::Io)?;
let private_keys_fn = match first_line.trim_end() {
"-----BEGIN RSA PRIVATE KEY-----" => pemfile::rsa_private_keys,
"-----BEGIN PRIVATE KEY-----" => pemfile::pkcs8_private_keys,
_ => return Err(Error::BadKey),
};
let key = private_keys_fn(&mut reader)
.map_err(|_| Error::BadKey)
.and_then(|mut keys| match keys.len() {
0 => Err(Error::BadKey),
1 => Ok(keys.remove(0)),
_ => Err(Error::BadKeyCount),
})?;
// Ensure we can use the key.
if rustls::sign::RSASigningKey::new(&key).is_err() {
Err(Error::BadKey)
} else {
Ok(key)
}
}
// TODO.async: Put these under a feature such as #[cfg(feature = "tokio-runtime")]
pub struct TlsListener {
listener: TcpListener,
acceptor: TlsAcceptor,
state: TlsListenerState,
}
enum TlsListenerState {
Listening,
Accepting(Pin<Box<dyn Future<Output=Result<TlsStream<TcpStream>, io::Error>> + Send>>),
}
impl Listener for TlsListener {
type Connection = TlsStream<TcpStream>;
fn local_addr(&self) -> Option<SocketAddr> {
self.listener.local_addr().ok()
}
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<Self::Connection, io::Error>> {
loop {
match &mut self.state {
TlsListenerState::Listening => {
match self.listener.poll_accept(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(stream)) => {
self.state = TlsListenerState::Accepting(Box::pin(self.acceptor.accept(stream)));
}
Poll::Ready(Err(e)) => {
return Poll::Ready(Err(e));
}
}
}
TlsListenerState::Accepting(fut) => {
match fut.as_mut().poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(result) => {
self.state = TlsListenerState::Listening;
return Poll::Ready(result);
}
}
}
}
}
}
}
pub fn bind_tls(address: SocketAddr, cert_chain: Vec<Certificate>, key: PrivateKey)
-> Pin<Box<dyn Future<Output=Result<TlsListener, io::Error>> + Send>>
{
Box::pin(async move {
let listener = TcpListener::bind(address).await?;
let client_auth = rustls::NoClientAuth::new();
let mut tls_config = ServerConfig::new(client_auth);
let cache = rustls::ServerSessionMemoryCache::new(1024);
tls_config.set_persistence(cache);
tls_config.ticketer = rustls::Ticketer::new();
tls_config.set_single_cert(cert_chain, key).expect("invalid key");
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
Ok(TlsListener {
listener,
acceptor,
state: TlsListenerState::Listening,
})
})
}
impl Connection for TlsStream<TcpStream> {
fn remote_addr(&self) -> Option<SocketAddr> {
self.get_ref().0.remote_addr()
}
}

View File

@ -574,33 +574,23 @@ impl Config {
/// ``` /// ```
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
pub fn set_tls(&mut self, certs_path: &str, key_path: &str) -> Result<()> { pub fn set_tls(&mut self, certs_path: &str, key_path: &str) -> Result<()> {
use crate::http::tls::pemfile::{certs, rsa_private_keys}; use crate::http::tls::{load_certs, load_private_key, Error};
use std::fs::File;
use std::io::BufReader;
let pem_err = "malformed PEM file"; let pem_err = "malformed PEM file";
// TODO.async: Fully copy from hyper-sync-rustls, move to http/src/tls
// Partially extracted from hyper-sync-rustls
// Load the certificates. // Load the certificates.
let certs = match File::open(self.root_relative(certs_path)) { let certs = load_certs(self.root_relative(certs_path))
Ok(file) => certs(&mut BufReader::new(file)).map_err(|_| { .map_err(|e| match e {
self.bad_type("tls", pem_err, "a valid certificates file") Error::Io(e) => ConfigError::Io(e, "tls.certs"),
}), _ => self.bad_type("tls", pem_err, "a valid certificates file")
Err(e) => Err(ConfigError::Io(e, "tls.certs"))?, })?;
}?;
// And now the private key. // And now the private key.
let mut keys = match File::open(self.root_relative(key_path)) { let key = load_private_key(self.root_relative(key_path))
Ok(file) => rsa_private_keys(&mut BufReader::new(file)).map_err(|_| { .map_err(|e| match e {
self.bad_type("tls", pem_err, "a valid private key file") Error::Io(e) => ConfigError::Io(e, "tls.key"),
}), _ => self.bad_type("tls", pem_err, "a valid private key file")
Err(e) => Err(ConfigError::Io(e, "tls.key")), })?;
}?;
// TODO.async: Proper check for one key
let key = keys.remove(0);
self.tls = Some(TlsConfig { certs, key }); self.tls = Some(TlsConfig { certs, key });
Ok(()) Ok(())

View File

@ -76,6 +76,8 @@ impl Data {
pub(crate) fn from_hyp(body: hyper::Body) -> impl Future<Output = Data> { pub(crate) fn from_hyp(body: hyper::Body) -> impl Future<Output = Data> {
// TODO.async: This used to also set the read timeout to 5 seconds. // TODO.async: This used to also set the read timeout to 5 seconds.
// Such a short read timeout is likely no longer necessary, but some
// kind of idle timeout should be implemented.
Data::new(body) Data::new(body)
} }

View File

@ -26,7 +26,7 @@ pub enum Error {
#[derive(Debug)] #[derive(Debug)]
pub enum LaunchErrorKind { pub enum LaunchErrorKind {
/// Binding to the provided address/port failed. /// Binding to the provided address/port failed.
Bind(hyper::Error), Bind(io::Error),
/// An I/O error occurred during launch. /// An I/O error occurred during launch.
Io(io::Error), Io(io::Error),
/// Route collisions were detected. /// Route collisions were detected.

View File

@ -3,21 +3,17 @@ use std::convert::{From, TryInto};
use std::cmp::min; use std::cmp::min;
use std::io; use std::io;
use std::mem; use std::mem;
use std::net::ToSocketAddrs;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use futures::future::{Future, FutureExt, BoxFuture}; use futures::future::{Future, FutureExt, BoxFuture};
use futures::channel::{mpsc, oneshot}; use futures::channel::{mpsc, oneshot};
use futures::stream::StreamExt; use futures::stream::StreamExt;
use futures::task::SpawnExt; use futures::task::{Spawn, SpawnExt};
use futures_tokio_compat::Compat as TokioCompat; use futures_tokio_compat::Compat as TokioCompat;
use yansi::Paint; use yansi::Paint;
use state::Container; use state::Container;
#[cfg(feature = "tls")] use crate::http::tls::TlsAcceptor;
use crate::{logger, handler}; use crate::{logger, handler};
use crate::config::{Config, FullConfig, ConfigError, LoggedValue}; use crate::config::{Config, FullConfig, ConfigError, LoggedValue};
use crate::request::{Request, FormItems}; use crate::request::{Request, FormItems};
@ -33,6 +29,7 @@ use crate::ext::AsyncReadExt;
use crate::shutdown::{ShutdownHandle, ShutdownHandleManaged}; use crate::shutdown::{ShutdownHandle, ShutdownHandleManaged};
use crate::http::{Method, Status, Header}; use crate::http::{Method, Status, Header};
use crate::http::private::{Listener, Connection, Incoming};
use crate::http::hyper::{self, header}; use crate::http::hyper::{self, header};
use crate::http::uri::Origin; use crate::http::uri::Origin;
@ -711,61 +708,32 @@ impl Rocket {
Ok(self) Ok(self)
} }
/// Similar to `launch()`, but using a custom Tokio runtime and returning // TODO.async: Solidify the Listener APIs and make this function public
/// a `Future` that completes along with the server. The runtime has no async fn listen_on<L, S>(mut self, listener: L, spawn: S) -> Result<(), crate::error::Error>
/// restrictions other than being Tokio-based, and can have other tasks where
/// running on it. L: Listener + Send + Unpin + 'static,
/// <L as Listener>::Connection: Send + Unpin + 'static,
/// # Example S: Spawn + Clone + Send + 'static,
/// {
/// ```rust self = self.prelaunch_check().map_err(crate::error::Error::Launch)?;
/// use futures::future::FutureExt;
///
/// // This gives us the default behavior. Alternatively, we could use a
/// // `tokio::runtime::Builder` to configure with greater detail.
/// let runtime = tokio::runtime::Runtime::new().expect("error creating runtime");
///
/// # if false {
/// let server_done = rocket::ignite().spawn_on(&runtime).expect("error launching server");
/// runtime.block_on(async move {
/// let result = server_done.await;
/// assert!(result.is_ok());
/// });
/// # }
/// ```
pub fn spawn_on(
mut self,
runtime: &tokio::runtime::Runtime,
) -> Result<impl Future<Output = Result<(), hyper::Error>>, LaunchError> {
#[cfg(feature = "tls")] use crate::http::tls;
self = self.prelaunch_check()?;
self.fairings.pretty_print_counts(); self.fairings.pretty_print_counts();
let full_addr = format!("{}:{}", self.config.address, self.config.port);
let addrs = match full_addr.to_socket_addrs() {
Ok(a) => a.collect::<Vec<_>>(),
// TODO.async: Reconsider this error type
Err(e) => return Err(From::from(io::Error::new(io::ErrorKind::Other, e))),
};
// TODO.async: support for TLS, unix sockets.
// Likely will be implemented with a custom "Incoming" type.
let mut incoming = match hyper::AddrIncoming::bind(&addrs[0]) {
Ok(incoming) => incoming,
Err(e) => return Err(LaunchError::new(LaunchErrorKind::Bind(e))),
};
// Determine the address and port we actually binded to. // Determine the address and port we actually binded to.
self.config.port = incoming.local_addr().port(); self.config.port = listener.local_addr().map(|a| a.port()).unwrap_or(0);
let proto = "http://"; let proto = if self.config.tls.is_some() {
"https://"
} else {
"http://"
};
let full_addr = format!("{}:{}", self.config.address, self.config.port);
// Set the keep-alive. // Set the keep-alive.
let timeout = self.config.keep_alive.map(|s| Duration::from_secs(s as u64)); // TODO.async: implement keep-alive in Listener
incoming.set_keepalive(timeout); // let timeout = self.config.keep_alive.map(|s| Duration::from_secs(s as u64));
// listener.set_keepalive(timeout);
// Freeze managed state for synchronization-free accesses later. // Freeze managed state for synchronization-free accesses later.
self.state.freeze(); self.state.freeze();
@ -782,41 +750,108 @@ impl Rocket {
// Restore the log level back to what it originally was. // Restore the log level back to what it originally was.
logger::pop_max_level(); logger::pop_max_level();
// We need to get these values before moving `self` into an `Arc`. // We need to get this before moving `self` into an `Arc`.
let mut shutdown_receiver = self.shutdown_receiver let mut shutdown_receiver = self.shutdown_receiver
.take().expect("shutdown receiver has already been used"); .take().expect("shutdown receiver has already been used");
#[cfg(feature = "ctrl_c_shutdown")]
let shutdown_handle = self.get_shutdown_handle();
let rocket = Arc::new(self); let rocket = Arc::new(self);
let spawn = Box::new(TokioCompat::new(runtime.executor())); let spawn_makeservice = spawn.clone();
let service = hyper::make_service_fn(move |socket: &hyper::AddrStream| { let service = hyper::make_service_fn(move |connection: &<L as Listener>::Connection| {
let rocket = rocket.clone(); let rocket = rocket.clone();
let remote_addr = socket.remote_addr(); let remote_addr = connection.remote_addr().unwrap_or_else(|| "0.0.0.0".parse().unwrap());
let spawn = spawn.clone(); let spawn_service = spawn_makeservice.clone();
async move { async move {
Ok::<_, std::convert::Infallible>(hyper::service_fn(move |req| { Ok::<_, std::convert::Infallible>(hyper::service_fn(move |req| {
hyper_service_fn(rocket.clone(), remote_addr, spawn.clone(), req) hyper_service_fn(rocket.clone(), remote_addr, spawn_service.clone(), req)
})) }))
} }
}); });
#[cfg(feature = "ctrl_c_shutdown")]
let (cancel_ctrl_c_listener_sender, cancel_ctrl_c_listener_receiver) = oneshot::channel();
// NB: executor must be passed manually here, see hyperium/hyper#1537 // NB: executor must be passed manually here, see hyperium/hyper#1537
let (future, handle) = hyper::Server::builder(incoming) hyper::Server::builder(Incoming::from_listener(listener))
.executor(runtime.executor()) .executor(TokioCompat::new(spawn))
.serve(service) .serve(service)
.with_graceful_shutdown(async move { shutdown_receiver.next().await; }) .with_graceful_shutdown(async move { shutdown_receiver.next().await; })
.inspect(|_| { .await
#[cfg(feature = "ctrl_c_shutdown")] .map_err(crate::error::Error::Run)
let _ = cancel_ctrl_c_listener_sender.send(()); }
})
.remote_handle();
runtime.spawn(future); /// Similar to `launch()`, but using a custom Tokio runtime and returning
/// a `Future` that completes along with the server. The runtime has no
/// restrictions other than being Tokio-based, and can have other tasks
/// running on it.
///
/// # Example
///
/// ```rust
/// use futures::future::FutureExt;
///
/// // This gives us the default behavior. Alternatively, we could use a
/// // `tokio::runtime::Builder` to configure with greater detail.
/// let runtime = tokio::runtime::Runtime::new().expect("error creating runtime");
///
/// # if false {
/// let server_done = rocket::ignite().spawn_on(&runtime);
/// runtime.block_on(async move {
/// let result = server_done.await;
/// assert!(result.is_ok());
/// });
/// # }
/// ```
pub fn spawn_on(
self,
runtime: &tokio::runtime::Runtime,
) -> impl Future<Output = Result<(), crate::error::Error>> {
use std::net::ToSocketAddrs;
use crate::error::Error::Launch;
let full_addr = format!("{}:{}", self.config.address, self.config.port);
let addrs = match full_addr.to_socket_addrs() {
Ok(a) => a.collect::<Vec<_>>(),
Err(e) => return futures::future::err(Launch(From::from(e))).boxed(),
};
let addr = addrs[0];
let spawn = TokioCompat::new(runtime.executor());
#[cfg(feature = "ctrl_c_shutdown")]
let (
shutdown_handle,
(cancel_ctrl_c_listener_sender, cancel_ctrl_c_listener_receiver)
) = (
self.get_shutdown_handle(),
oneshot::channel()
);
let server = async move {
macro_rules! listen_on {
($spawn:expr, $expr:expr) => {{
let listener = match $expr {
Ok(ok) => ok,
Err(err) => return Err(Launch(LaunchError::new(LaunchErrorKind::Bind(err)))),
};
self.listen_on(listener, spawn).await
}};
}
#[cfg(feature = "tls")]
{
if let Some(tls) = self.config.tls.clone() {
listen_on!(spawn, crate::http::tls::bind_tls(addr, tls.certs, tls.key).await)
} else {
listen_on!(spawn, crate::http::private::bind_tcp(addr).await)
}
}
#[cfg(not(feature = "tls"))]
{
listen_on!(spawn, crate::http::private::bind_tcp(addr).await)
}
};
#[cfg(feature = "ctrl_c_shutdown")]
let server = server.inspect(|_| {
let _ = cancel_ctrl_c_listener_sender.send(());
});
#[cfg(feature = "ctrl_c_shutdown")] #[cfg(feature = "ctrl_c_shutdown")]
match tokio::net::signal::ctrl_c() { match tokio::net::signal::ctrl_c() {
@ -843,7 +878,7 @@ impl Rocket {
}, },
} }
Ok(handle) server.boxed()
} }
/// Starts the application server and begins listening for and dispatching /// Starts the application server and begins listening for and dispatching
@ -866,8 +901,6 @@ impl Rocket {
/// # } /// # }
/// ``` /// ```
pub fn launch(self) -> Result<(), crate::error::Error> { pub fn launch(self) -> Result<(), crate::error::Error> {
use crate::error::Error;
// TODO.async What meaning should config.workers have now? // TODO.async What meaning should config.workers have now?
// Initialize the tokio runtime // Initialize the tokio runtime
let runtime = tokio::runtime::Builder::new() let runtime = tokio::runtime::Builder::new()
@ -875,10 +908,7 @@ impl Rocket {
.build() .build()
.expect("Cannot build runtime!"); .expect("Cannot build runtime!");
match self.spawn_on(&runtime) { runtime.block_on(self.spawn_on(&runtime))
Ok(fut) => runtime.block_on(fut).map_err(Error::Run),
Err(err) => Err(Error::Launch(err)),
}
} }
/// Returns a [`ShutdownHandle`], which can be used to gracefully terminate /// Returns a [`ShutdownHandle`], which can be used to gracefully terminate
@ -893,19 +923,17 @@ impl Rocket {
/// # /// #
/// let rocket = rocket::ignite(); /// let rocket = rocket::ignite();
/// let handle = rocket.get_shutdown_handle(); /// let handle = rocket.get_shutdown_handle();
/// # let real_handle = rocket.get_shutdown_handle();
/// ///
/// # if false { /// # if false {
/// thread::spawn(move || { /// thread::spawn(move || {
/// thread::sleep(Duration::from_secs(10)); /// thread::sleep(Duration::from_secs(10));
/// handle.shutdown(); /// handle.shutdown();
/// }); /// });
/// # }
/// # real_handle.shutdown();
/// ///
/// // Shuts down after 10 seconds /// // Shuts down after 10 seconds
/// let shutdown_result = rocket.launch(); /// let shutdown_result = rocket.launch();
/// assert!(shutdown_result.is_ok()); /// assert!(shutdown_result.is_ok());
/// # }
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn get_shutdown_handle(&self) -> ShutdownHandle { pub fn get_shutdown_handle(&self) -> ShutdownHandle {