diff --git a/core/codegen/src/attribute/async_entry.rs b/core/codegen/src/attribute/async_entry.rs index dfbe8b2c..b42a3dbc 100644 --- a/core/codegen/src/attribute/async_entry.rs +++ b/core/codegen/src/attribute/async_entry.rs @@ -7,7 +7,7 @@ trait EntryAttr { const REQUIRES_ASYNC: bool; /// Return a new or rewritten function, using block as the main execution. - fn function(f: &syn::ItemFn, body: &syn::Block) -> Result; + fn function(f: &mut syn::ItemFn) -> Result; } struct Main; @@ -15,8 +15,8 @@ struct Main; impl EntryAttr for Main { const REQUIRES_ASYNC: bool = true; - fn function(f: &syn::ItemFn, block: &syn::Block) -> Result { - let (attrs, vis, mut sig) = (&f.attrs, &f.vis, f.sig.clone()); + fn function(f: &mut syn::ItemFn) -> Result { + let (attrs, vis, block, sig) = (&f.attrs, &f.vis, &f.block, &mut f.sig); if sig.ident != "main" { // FIXME(diag): warning! Span::call_site() @@ -37,8 +37,8 @@ struct Test; impl EntryAttr for Test { const REQUIRES_ASYNC: bool = true; - fn function(f: &syn::ItemFn, block: &syn::Block) -> Result { - let (attrs, vis, mut sig) = (&f.attrs, &f.vis, f.sig.clone()); + fn function(f: &mut syn::ItemFn) -> Result { + let (attrs, vis, block, sig) = (&f.attrs, &f.vis, &f.block, &mut f.sig); sig.asyncness = None; Ok(quote_spanned!(block.span().into() => #(#attrs)* #[test] #vis #sig { ::rocket::async_test(async move #block) @@ -51,7 +51,7 @@ struct Launch; impl EntryAttr for Launch { const REQUIRES_ASYNC: bool = false; - fn function(f: &syn::ItemFn, block: &syn::Block) -> Result { + fn function(f: &mut syn::ItemFn) -> Result { if f.sig.ident == "main" { return Err(Span::call_site() .error("attribute cannot be applied to `main` function") @@ -59,6 +59,14 @@ impl EntryAttr for Launch { .span_note(f.sig.ident.span(), "this function cannot be `main`")); } + // Always infer the type as `::rocket::Rocket`. + if let syn::ReturnType::Type(_, ref mut ty) = &mut f.sig.output { + if let syn::Type::Infer(_) = &mut **ty { + let new = quote_spanned!(ty.span() => ::rocket::Rocket); + *ty = syn::parse2(new).expect("path is type"); + } + } + let ty = match &f.sig.output { syn::ReturnType::Type(_, ty) => ty, _ => return Err(Span::call_site() @@ -66,13 +74,13 @@ impl EntryAttr for Launch { .span_note(f.sig.span(), "this function must return a value")) }; + let block = &f.block; let rocket = quote_spanned!(ty.span().into() => { let ___rocket: #ty = #block; let ___rocket: ::rocket::Rocket = ___rocket; ___rocket }); - // FIXME: Don't duplicate the `#block` here! let (vis, mut sig) = (&f.vis, f.sig.clone()); sig.ident = syn::Ident::new("main", sig.ident.span()); sig.output = syn::ReturnType::Default; @@ -112,8 +120,8 @@ fn _async_entry( _args: proc_macro::TokenStream, input: proc_macro::TokenStream ) -> Result { - let function = parse_input::(input)?; - A::function(&function, &function.block).map(|t| t.into()) + let mut function = parse_input::(input)?; + A::function(&mut function).map(|t| t.into()) } macro_rules! async_entry { diff --git a/core/codegen/tests/async-entry.rs b/core/codegen/tests/async-entry.rs index c4851c69..d64b5240 100644 --- a/core/codegen/tests/async-entry.rs +++ b/core/codegen/tests/async-entry.rs @@ -25,6 +25,15 @@ mod b { } } +mod b_inferred { + #[rocket::launch] + async fn main2() -> _ { rocket::ignite() } + + async fn use_it() { + let rocket: rocket::Rocket = main2().await; + } +} + mod c { // non-async launch. #[rocket::launch] @@ -37,6 +46,15 @@ mod c { } } +mod c_inferred { + #[rocket::launch] + fn rocket() -> _ { rocket::ignite() } + + fn use_it() { + let rocket: rocket::Rocket = rocket(); + } +} + mod d { // main with async, is async. #[rocket::main]