From 60f3cd57b06243beaee87fd5b7545e3bf0fa6f60 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Fri, 12 Apr 2024 03:11:10 -0700 Subject: [PATCH] Add end-to-end testbench. Resolves #1509. --- .github/workflows/ci.yml | 2 + core/http/src/parse/uri/error.rs | 2 + scripts/config.sh | 2 + scripts/test.sh | 11 +- testbench/Cargo.toml | 27 ++++ testbench/src/client.rs | 206 +++++++++++++++++++++++++++++++ testbench/src/lib.rs | 3 + testbench/src/main.rs | 94 ++++++++++++++ 8 files changed, 346 insertions(+), 1 deletion(-) create mode 100644 testbench/Cargo.toml create mode 100644 testbench/src/client.rs create mode 100644 testbench/src/lib.rs create mode 100644 testbench/src/main.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 31f173bb..c4049193 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,6 +28,8 @@ jobs: test: { name: Core, flag: "--core" } - platform: { name: Linux, distro: ubuntu-latest, toolchain: stable } test: { name: Release, flag: "--release" } + - platform: { name: Linux, distro: ubuntu-latest, toolchain: stable } + test: { name: Testbench, flag: "--testbench" } - platform: { name: Linux, distro: ubuntu-latest, toolchain: stable } test: { name: UI, flag: "--ui" } fallible: true diff --git a/core/http/src/parse/uri/error.rs b/core/http/src/parse/uri/error.rs index 1deda621..c78c3491 100644 --- a/core/http/src/parse/uri/error.rs +++ b/core/http/src/parse/uri/error.rs @@ -61,6 +61,8 @@ impl IntoOwned for Error<'_> { } } +impl std::error::Error for Error<'_> { } + #[cfg(test)] mod tests { use crate::parse::uri::origin_from_str; diff --git a/scripts/config.sh b/scripts/config.sh index e4be4c55..2c700528 100755 --- a/scripts/config.sh +++ b/scripts/config.sh @@ -39,6 +39,7 @@ function future_date() { PROJECT_ROOT=$(relative "") || exit $? CONTRIB_ROOT=$(relative "contrib") || exit $? BENCHMARKS_ROOT=$(relative "benchmarks") || exit $? +TESTBENCH_ROOT=$(relative "testbench") || exit $? FUZZ_ROOT=$(relative "core/lib/fuzz") || exit $? # Root of project-like directories. @@ -87,6 +88,7 @@ function print_environment() { echo " CONTRIB_ROOT: ${CONTRIB_ROOT}" echo " FUZZ_ROOT: ${FUZZ_ROOT}" echo " BENCHMARKS_ROOT: ${BENCHMARKS_ROOT}" + echo " TESTBENCH_ROOT: ${TESTBENCH_ROOT}" echo " CORE_LIB_ROOT: ${CORE_LIB_ROOT}" echo " CORE_CODEGEN_ROOT: ${CORE_CODEGEN_ROOT}" echo " CORE_HTTP_ROOT: ${CORE_HTTP_ROOT}" diff --git a/scripts/test.sh b/scripts/test.sh index ad234c48..5212be19 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -184,6 +184,12 @@ function run_benchmarks() { indir "${BENCHMARKS_ROOT}" $CARGO bench $@ } +function run_testbench() { + echo ":: Running testbench..." + indir "${TESTBENCH_ROOT}" $CARGO update + indir "${TESTBENCH_ROOT}" $CARGO run $@ +} + if [[ $1 == +* ]]; then CARGO="$CARGO $1" shift @@ -191,7 +197,7 @@ fi # The kind of test we'll be running. TEST_KIND="default" -KINDS=("contrib" "benchmarks" "core" "examples" "default" "ui" "all") +KINDS=("contrib" "benchmarks" "testbench" "core" "examples" "default" "ui" "all") if [[ " ${KINDS[@]} " =~ " ${1#"--"} " ]]; then TEST_KIND=${1#"--"} @@ -226,12 +232,14 @@ case $TEST_KIND in examples) test_examples $@ ;; default) test_default $@ ;; benchmarks) run_benchmarks $@ ;; + testbench) run_testbench $@ ;; ui) test_ui $@ ;; all) test_default $@ & default=$! test_examples $@ & examples=$! test_core $@ & core=$! test_contrib $@ & contrib=$! + run_testbench $@ & testbench=$! test_ui $@ & ui=$! failures=() @@ -239,6 +247,7 @@ case $TEST_KIND in if ! wait $examples ; then failures+=("EXAMPLES"); fi if ! wait $core ; then failures+=("CORE"); fi if ! wait $contrib ; then failures+=("CONTRIB"); fi + if ! wait $testbench ; then failures+=("TESTBENCH"); fi if ! wait $ui ; then failures+=("UI"); fi if [ ${#failures[@]} -ne 0 ]; then diff --git a/testbench/Cargo.toml b/testbench/Cargo.toml new file mode 100644 index 00000000..57b5e9ff --- /dev/null +++ b/testbench/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "rocket-testbench" +description = "end-to-end HTTP testbench for Rocket" +version = "0.0.0" +edition = "2021" +publish = false + +[workspace] + +[dependencies] +thiserror = "1.0" +procspawn = "1" +pretty_assertions = "1.4.0" +ipc-channel = "0.18" + +[dependencies.nix] +version = "0.28" +features = ["signal"] + +[dependencies.rocket] +path = "../core/lib/" +features = ["secrets", "tls", "mtls"] + +[dependencies.reqwest] +version = "0.12.3" +default-features = false +features = ["rustls-tls-manual-roots", "charset", "cookies", "blocking", "http2"] diff --git a/testbench/src/client.rs b/testbench/src/client.rs new file mode 100644 index 00000000..ec11374d --- /dev/null +++ b/testbench/src/client.rs @@ -0,0 +1,206 @@ +use std::time::Duration; +use std::sync::Once; +use std::process::Stdio; +use std::io::{self, Read}; + +use rocket::fairing::AdHoc; +use rocket::http::ext::IntoOwned; +use rocket::http::uri::{self, Absolute, Uri}; +use rocket::serde::{Deserialize, Serialize}; +use rocket::{Build, Rocket}; + +use procspawn::SpawnError; +use thiserror::Error; +use ipc_channel::ipc::{IpcOneShotServer, IpcReceiver, IpcSender}; + +static DEFAULT_CONFIG: &str = r#" + [default] + address = "tcp:127.0.0.1" + workers = 2 + port = 0 + cli_colors = false + secret_key = "itlYmFR2vYKrOmFhupMIn/hyB6lYCCTXz4yaQX89XVg=" + + [default.shutdown] + grace = 1 + mercy = 1 +"#; + +#[derive(Debug)] +#[allow(unused)] +pub struct Client { + client: reqwest::blocking::Client, + server: procspawn::JoinHandle<()>, + tls: bool, + port: u16, + rx: IpcReceiver, +} + +#[derive(Error, Debug)] +pub enum Error { + #[error("join/kill failed: {0}")] + JoinError(#[from] SpawnError), + #[error("kill failed: {0}")] + TermFailure(#[from] nix::errno::Errno), + #[error("i/o error: {0}")] + Io(#[from] io::Error), + #[error("invalid URI: {0}")] + Uri(#[from] uri::Error<'static>), + #[error("the URI is invalid")] + InvalidUri, + #[error("bad request: {0}")] + Request(#[from] reqwest::Error), + #[error("IPC failure: {0}")] + Ipc(#[from] ipc_channel::ipc::IpcError), + #[error("liftoff failed")] + Liftoff(String, String), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(crate = "rocket::serde")] +pub enum Message { + Liftoff(bool, u16), + Failure, +} + +#[derive(Serialize, Deserialize)] +#[serde(crate = "rocket::serde")] +#[must_use] +pub struct Token(String); + +pub type Result = std::result::Result; + +impl Token { + fn configure(&self, toml: &str, rocket: Rocket) -> Rocket { + use rocket::figment::{Figment, providers::{Format, Toml}}; + + let toml = toml.replace("{CRATE}", env!("CARGO_MANIFEST_DIR")); + let config = Figment::from(rocket.figment()) + .merge(Toml::string(DEFAULT_CONFIG).nested()) + .merge(Toml::string(&toml).nested()); + + let server = self.0.clone(); + rocket.configure(config) + .attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move { + let tcp = rocket.endpoints().find_map(|e| e.tcp()).unwrap(); + let tls = rocket.endpoints().any(|e| e.is_tls()); + let sender = IpcSender::::connect(server).unwrap(); + let _ = sender.send(Message::Liftoff(tls, tcp.port())); + let _ = sender.send(Message::Liftoff(tls, tcp.port())); + }))) + } + + pub fn rocket(&self, toml: &str) -> Rocket { + self.configure(toml, rocket::build()) + } + + pub fn configured_launch(self, toml: &str, rocket: Rocket) { + let rocket = self.configure(toml, rocket); + if let Err(e) = rocket::execute(rocket.launch()) { + let sender = IpcSender::::connect(self.0).unwrap(); + let _ = sender.send(Message::Failure); + let _ = sender.send(Message::Failure); + e.pretty_print(); + std::process::exit(1); + } + } + + pub fn launch(self, rocket: Rocket) { + self.configured_launch(DEFAULT_CONFIG, rocket) + } +} +pub fn start(f: fn(Token)) -> Result { + static INIT: Once = Once::new(); + INIT.call_once(procspawn::init); + + let (ipc, server) = IpcOneShotServer::new()?; + let mut server = procspawn::Builder::new() + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn(Token(server), f); + + let client = reqwest::blocking::Client::builder() + .danger_accept_invalid_certs(true) + .cookie_store(true) + .tls_info(true) + .timeout(Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(5)) + .build()?; + + let (rx, _) = ipc.accept().unwrap(); + match rx.recv() { + Ok(Message::Liftoff(tls, port)) => Ok(Client { client, server, tls, port, rx }), + Ok(Message::Failure) => { + let stdout = server.stdout().unwrap(); + let mut out = String::new(); + stdout.read_to_string(&mut out)?; + + let stderr = server.stderr().unwrap(); + let mut err = String::new(); + stderr.read_to_string(&mut err)?; + Err(Error::Liftoff(out, err)) + } + Err(e) => Err(e.into()), + } + +} + +pub fn default() -> Result { + start(|token| token.launch(rocket::build())) +} + +impl Client { + pub fn read_stdout(&mut self) -> Result { + let Some(stdout) = self.server.stdout() else { + return Ok(String::new()); + }; + + let mut string = String::new(); + stdout.read_to_string(&mut string)?; + Ok(string) + } + + pub fn read_stderr(&mut self) -> Result { + let Some(stderr) = self.server.stderr() else { + return Ok(String::new()); + }; + + let mut string = String::new(); + stderr.read_to_string(&mut string)?; + Ok(string) + } + + pub fn kill(&mut self) -> Result<()> { + Ok(self.server.kill()?) + } + + pub fn terminate(&mut self) -> Result<()> { + use nix::{sys::signal, unistd::Pid}; + + let pid = Pid::from_raw(self.server.pid().unwrap() as i32); + Ok(signal::kill(pid, signal::SIGTERM)?) + } + + pub fn wait(&mut self) -> Result<()> { + match self.server.join_timeout(Duration::from_secs(5)) { + Ok(_) => Ok(()), + Err(e) if e.is_remote_close() => Ok(()), + Err(e) => Err(e.into()), + } + } + + pub fn get(&self, url: &str) -> Result { + let uri = match Uri::parse_any(url).map_err(|e| e.into_owned())? { + Uri::Origin(uri) => { + let proto = if self.tls { "https" } else { "http" }; + let uri = format!("{proto}://127.0.0.1:{}{uri}", self.port); + Absolute::parse_owned(uri)? + } + Uri::Absolute(uri) => uri, + _ => return Err(Error::InvalidUri), + }; + + Ok(self.client.get(uri.to_string())) + } +} diff --git a/testbench/src/lib.rs b/testbench/src/lib.rs new file mode 100644 index 00000000..c8ab57bd --- /dev/null +++ b/testbench/src/lib.rs @@ -0,0 +1,3 @@ +pub mod client; + +pub use client::*; diff --git a/testbench/src/main.rs b/testbench/src/main.rs new file mode 100644 index 00000000..22ecd8dd --- /dev/null +++ b/testbench/src/main.rs @@ -0,0 +1,94 @@ +use rocket::{fairing::AdHoc, *}; +use rocket_testbench::client::{self, Error}; +use reqwest::tls::TlsInfo; + +fn run() -> client::Result<()> { + let mut client = client::start(|token| { + #[get("/")] + fn index() -> &'static str { + "Hello, world!" + } + + token.configured_launch(r#" + [default.tls] + certs = "{CRATE}/../examples/tls/private/rsa_sha256_cert.pem" + key = "{CRATE}/../examples/tls/private/rsa_sha256_key.pem" + "#, rocket::build().mount("/", routes![index])); + })?; + + let response = client.get("/")?.send()?; + let tls = response.extensions().get::().unwrap(); + assert!(!tls.peer_certificate().unwrap().is_empty()); + assert_eq!(response.text()?, "Hello, world!"); + + client.terminate()?; + let stdout = client.read_stdout()?; + assert!(stdout.contains("Rocket has launched on https")); + assert!(stdout.contains("Graceful shutdown completed")); + assert!(stdout.contains("GET /")); + Ok(()) +} + +fn run_fail() -> client::Result<()> { + let client = client::start(|token| { + let fail = AdHoc::try_on_ignite("FailNow", |rocket| async { Err(rocket) }); + token.launch(rocket::build().attach(fail)); + }); + + if let Err(Error::Liftoff(stdout, _)) = client { + assert!(stdout.contains("Rocket failed to launch due to failing fairings")); + assert!(stdout.contains("FailNow")); + } else { + panic!("unexpected result: {client:#?}"); + } + + Ok(()) +} + +fn infinite() -> client::Result<()> { + use rocket::response::stream::TextStream; + + let mut client = client::start(|token| { + #[get("/")] + fn infinite() -> TextStream![&'static str] { + TextStream! { + loop { + yield rocket::futures::future::pending::<&str>().await; + } + } + } + + token.launch(rocket::build().mount("/", routes![infinite])); + })?; + + client.get("/")?.send()?; + client.terminate()?; + let stdout = client.read_stdout()?; + assert!(stdout.contains("Rocket has launched on http")); + assert!(stdout.contains("GET /")); + assert!(stdout.contains("Graceful shutdown completed")); + Ok(()) +} + +fn main() { + let names = ["run", "run_fail", "infinite"]; + let tests = [run, run_fail, infinite]; + let handles = tests.into_iter() + .map(|test| std::thread::spawn(test)) + .collect::>(); + + let mut failure = false; + for (handle, name) in handles.into_iter().zip(names) { + let result = handle.join(); + failure = failure || matches!(result, Ok(Err(_)) | Err(_)); + match result { + Ok(Ok(_)) => continue, + Ok(Err(e)) => eprintln!("{name} failed: {e}"), + Err(_) => eprintln!("{name} failed (see panic above)"), + } + } + + if failure { + std::process::exit(1); + } +}