Warn if a task is spawned in a sync '#[launch]'.

The warning is fairly conservative. Heuristics are used to determine if a call
to `tokio::spawn()` occurs in the `#[launch]` function.

Addresses #2547.
This commit is contained in:
Sergio Benitez 2023-05-24 11:33:56 -07:00
parent 5024dbf694
commit 4b7d48967b
1 changed files with 53 additions and 0 deletions

View File

@ -9,6 +9,48 @@ use proc_macro2::{TokenStream, Span};
/// returned instance inside of an `rocket::async_main`. /// returned instance inside of an `rocket::async_main`.
pub struct Launch; pub struct Launch;
/// Determines if `f` likely spawns an async task, returning the spawn call.
fn likely_spawns(f: &syn::ItemFn) -> Option<&syn::ExprCall> {
use syn::visit::{self, Visit};
struct SpawnFinder<'a>(Option<&'a syn::ExprCall>);
impl<'ast> Visit<'ast> for SpawnFinder<'ast> {
fn visit_expr_call(&mut self, i: &'ast syn::ExprCall) {
if self.0.is_some() {
return;
}
if let syn::Expr::Path(ref e) = *i.func {
let mut segments = e.path.segments.clone();
if let Some(last) = segments.pop() {
if last.value().ident != "spawn" {
return visit::visit_expr_call(self, i);
}
if let Some(prefix) = segments.pop() {
if prefix.value().ident == "tokio" {
self.0 = Some(i);
return;
}
}
if let Some(syn::Expr::Async(_)) = i.args.first() {
self.0 = Some(i);
return;
}
}
};
visit::visit_expr_call(self, i);
}
}
let mut v = SpawnFinder(None);
v.visit_item_fn(f);
v.0
}
impl EntryAttr for Launch { impl EntryAttr for Launch {
const REQUIRES_ASYNC: bool = false; const REQUIRES_ASYNC: bool = false;
@ -47,6 +89,17 @@ impl EntryAttr for Launch {
None => quote_spanned!(ty.span() => #rocket.launch()), None => quote_spanned!(ty.span() => #rocket.launch()),
}; };
if f.sig.asyncness.is_none() {
if let Some(call) = likely_spawns(f) {
call.span()
.warning("task is being spawned outside an async context")
.span_help(f.sig.span(), "declare this function as `async fn` \
to require async execution")
.span_note(Span::call_site(), "`#[launch]` call is here")
.emit_as_expr_tokens();
}
}
let (vis, mut sig) = (&f.vis, f.sig.clone()); let (vis, mut sig) = (&f.vis, f.sig.clone());
sig.ident = syn::Ident::new("main", sig.ident.span()); sig.ident = syn::Ident::new("main", sig.ident.span());
sig.output = syn::ReturnType::Default; sig.output = syn::ReturnType::Default;