godot/modules/mono/glue/GodotSharp/Godot.SourceGenerators.Internal/UnmanagedCallbacksGenerator.cs
Ignacio Roldán Etcheverry 2c180f62d9 C#: Replace P/Invoke with delegate pointers
- Moves interop functions to an UnmanagedCallbacks struct that
  contains the function pointers and is passed to C#.

- Implements UnmanagedCallbacksGenerator, a C# source generator that
  generates the UnmanagedCallbacks struct in C# and the body for the
  NativeFuncs methods (their implementation just calls the function
  pointer in the UnmanagedCallbacks). The generated methods are needed
  because .NET pins byref parameters of native calls, even if they are
  'ref struct's, which don't need pinning. The generated methods use
  `Unsafe.AsPointer` so that we can benefit from byref parameters
  without suffering the overhead of pinning.

Co-authored-by: Raul Santos <raulsntos@gmail.com>
2022-08-22 03:36:52 +02:00

464 lines
16 KiB
C#

using System.Text;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace Godot.SourceGenerators.Internal;
[Generator]
public class UnmanagedCallbacksGenerator : ISourceGenerator
{
    /// <summary>
    /// Registers a post-initialization callback that adds the
    /// <c>GenerateUnmanagedCallbacksAttribute</c> declaration to the compilation.
    /// </summary>
    public void Initialize(GeneratorInitializationContext context)
    {
        context.RegisterForPostInitialization(ctx => { GenerateAttribute(ctx); });
    }

    /// <summary>
    /// For each class selected by <c>SelectUnmanagedCallbacksClasses</c>, generates the bodies of
    /// its partial interop methods and the function-pointer struct named by the single constructor
    /// argument of its <c>GenerateUnmanagedCallbacks</c> attribute.
    /// </summary>
    /// <remarks>
    /// Classes that are not partial, or that are nested inside a non-partial type, are reported
    /// through <c>Common</c> and skipped. Duplicate symbols (e.g. from multiple partial
    /// declarations) are processed only once.
    /// </remarks>
    public void Execute(GeneratorExecutionContext context)
    {
        INamedTypeSymbol[] unmanagedCallbacksClasses = context
            .Compilation.SyntaxTrees
            .SelectMany(tree =>
                tree.GetRoot().DescendantNodes()
                    .OfType<ClassDeclarationSyntax>()
                    .SelectUnmanagedCallbacksClasses(context.Compilation)
                    // Report and skip non-partial classes
                    .Where(x =>
                    {
                        if (x.cds.IsPartial())
                        {
                            if (x.cds.IsNested() && !x.cds.AreAllOuterTypesPartial(out var typeMissingPartial))
                            {
                                Common.ReportNonPartialUnmanagedCallbacksOuterClass(context, typeMissingPartial!);
                                return false;
                            }

                            return true;
                        }

                        Common.ReportNonPartialUnmanagedCallbacksClass(context, x.cds, x.symbol);
                        return false;
                    })
                    .Select(x => x.symbol)
            )
            .Distinct<INamedTypeSymbol>(SymbolEqualityComparer.Default)
            .ToArray();

        foreach (var symbol in unmanagedCallbacksClasses)
        {
            var attr = symbol.GetGenerateUnmanagedCallbacksAttribute();
            if (attr == null || attr.ConstructorArguments.Length != 1)
            {
                // TODO: Report error or throw exception, this is an invalid case and should never be reached
                System.Diagnostics.Debug.Fail("FAILED!");
                continue;
            }

            var funcStructType = (INamedTypeSymbol?)attr.ConstructorArguments[0].Value;
            if (funcStructType == null)
            {
                // TODO: Report error or throw exception, this is an invalid case and should never be reached
                System.Diagnostics.Debug.Fail("FAILED!");
                continue;
            }

            var data = new CallbacksData(symbol, funcStructType);
            GenerateInteropMethodImplementations(context, data);
            GenerateUnmanagedCallbacksStruct(context, data);
        }
    }

    /// <summary>
    /// Adds the source of the internal <c>GenerateUnmanagedCallbacksAttribute</c>, whose single
    /// constructor argument (<c>FuncStructType</c>) names the struct that holds the function pointers.
    /// </summary>
    private void GenerateAttribute(GeneratorPostInitializationContext context)
    {
        string source = @"using System;
namespace Godot.SourceGenerators.Internal
{
internal class GenerateUnmanagedCallbacksAttribute : Attribute
{
public Type FuncStructType { get; }
public GenerateUnmanagedCallbacksAttribute(Type funcStructType)
{
FuncStructType = funcStructType;
}
}
}";
        context.AddSource("GenerateUnmanagedCallbacksAttribute.generated",
            SourceText.From(source, Encoding.UTF8));
    }

    /// <summary>
    /// Emits a partial declaration of the native type in <paramref name="data"/> that holds a static
    /// <c>_unmanagedCallbacks</c> field and implements each partial callback method by forwarding
    /// its arguments to the function pointer of the same name in that field.
    /// </summary>
    /// <remarks>
    /// By-ref parameters are lowered as follows:
    /// <list type="bullet">
    /// <item>Godot interop structs are passed as pointers obtained through <c>CustomUnsafe</c>.</item>
    /// <item>Other value types are copied to a local whose address is passed; for <c>ref</c>/<c>out</c>,
    /// the local is copied back into the parameter after the call.</item>
    /// <item>Anything else is passed with its original ref kind, and is therefore pinned by the runtime.</item>
    /// </list>
    /// </remarks>
    private void GenerateInteropMethodImplementations(GeneratorExecutionContext context, CallbacksData data)
    {
        var symbol = data.NativeTypeSymbol;

        string classNs = GetContainingNamespaceName(symbol);
        bool hasNamespace = classNs.Length != 0;

        var source = new StringBuilder();
        // Per-method scratch buffers, cleared at the start of every callback.
        var methodSource = new StringBuilder();
        var methodCallArguments = new StringBuilder();
        var methodSourceAfterCall = new StringBuilder();

        source.Append(
            @"using System;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Godot.Bridge;
using Godot.NativeInterop;
#pragma warning disable CA1707 // Disable warning: Identifiers should not contain underscores
");

        if (hasNamespace)
        {
            source.Append("namespace ");
            source.Append(classNs);
            source.Append("\n{\n");
        }

        // Outermost type first, so the nesting matches the real declaration.
        AppendPartialContainingTypeDeclarations(source, symbol.ContainingType);

        source.Append("[System.Runtime.CompilerServices.SkipLocalsInit]\n");
        source.Append($"unsafe partial class {symbol.Name}\n");
        source.Append("{\n");
        source.Append($" private static {data.FuncStructSymbol.FullQualifiedName()} _unmanagedCallbacks;\n\n");

        foreach (var callback in data.Methods)
        {
            methodSource.Clear();
            methodCallArguments.Clear();
            methodSourceAfterCall.Clear();

            source.Append(" [global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]\n");
            source.Append($" {SyntaxFacts.GetText(callback.DeclaredAccessibility)} ");

            if (callback.IsStatic)
                source.Append("static ");

            source.Append("partial ");
            source.Append(callback.ReturnType.FullQualifiedName());
            source.Append(' ');
            source.Append(callback.Name);
            source.Append('(');

            for (int i = 0; i < callback.Parameters.Length; i++)
            {
                var parameter = callback.Parameters[i];

                source.Append(parameter.ToDisplayString());
                source.Append(' ');
                source.Append(parameter.Name);

                if (parameter.RefKind == RefKind.Out)
                {
                    // Only assign default if the parameter won't be passed by-ref or copied later.
                    // Interop structs have their address taken, which requires definite assignment.
                    if (IsGodotInteropStruct(parameter.Type))
                        methodSource.Append($" {parameter.Name} = default;\n");
                }

                if (IsByRefParameter(parameter))
                {
                    if (IsGodotInteropStruct(parameter.Type))
                    {
                        methodSource.Append(" ");
                        AppendCustomUnsafeAsPointer(methodSource, parameter, out string varName);
                        methodCallArguments.Append(varName);
                    }
                    else if (parameter.Type.IsValueType)
                    {
                        methodSource.Append(" ");
                        AppendCopyToStackAndGetPointer(methodSource, parameter, out string varName);
                        methodCallArguments.Append($"&{varName}");

                        if (parameter.RefKind is RefKind.Out or RefKind.Ref)
                        {
                            // Propagate the callee's writes back to the caller's variable.
                            methodSourceAfterCall.Append($" {parameter.Name} = {varName};\n");
                        }
                    }
                    else
                    {
                        // If it's a by-ref param and we can't get the pointer
                        // just pass it by-ref and let it be pinned.
                        AppendRefKind(methodCallArguments, parameter.RefKind)
                            .Append(' ')
                            .Append(parameter.Name);
                    }
                }
                else
                {
                    methodCallArguments.Append(parameter.Name);
                }

                if (i < callback.Parameters.Length - 1)
                {
                    source.Append(", ");
                    methodCallArguments.Append(", ");
                }
            }

            source.Append(")\n");
            source.Append(" {\n");
            source.Append(methodSource);
            source.Append(" ");

            if (!callback.ReturnsVoid)
            {
                // When values must be copied back after the call, the result is stashed
                // in a local and returned at the end instead of returned directly.
                if (methodSourceAfterCall.Length != 0)
                    source.Append($"{callback.ReturnType.FullQualifiedName()} ret = ");
                else
                    source.Append("return ");
            }

            source.Append($"_unmanagedCallbacks.{callback.Name}(");
            source.Append(methodCallArguments);
            source.Append(");\n");

            if (methodSourceAfterCall.Length != 0)
            {
                source.Append(methodSourceAfterCall);

                if (!callback.ReturnsVoid)
                    source.Append(" return ret;\n");
            }

            source.Append(" }\n\n");
        }

        source.Append("}\n");

        AppendContainingTypeClosingBraces(source, symbol.ContainingType);

        if (hasNamespace)
            source.Append("\n}");

        source.Append("\n\n#pragma warning restore CA1707\n");

        context.AddSource($"{data.NativeTypeSymbol.FullQualifiedName().SanitizeQualifiedNameForUniqueHint()}.generated",
            SourceText.From(source.ToString(), Encoding.UTF8));
    }

    /// <summary>
    /// Emits a sequential-layout partial declaration of the function-pointer struct in
    /// <paramref name="data"/>, with one <c>delegate* unmanaged</c> field per callback method.
    /// </summary>
    /// <remarks>
    /// The parameter types mirror the lowering done in
    /// <see cref="GenerateInteropMethodImplementations"/>: by-ref Godot interop structs and
    /// by-ref value types become pointers, and other by-ref parameters keep their ref kind.
    /// The last type argument is the return type. Public callbacks get public fields;
    /// everything else is emitted as internal.
    /// </remarks>
    private void GenerateUnmanagedCallbacksStruct(GeneratorExecutionContext context, CallbacksData data)
    {
        var symbol = data.FuncStructSymbol;

        string classNs = GetContainingNamespaceName(symbol);
        bool hasNamespace = classNs.Length != 0;

        var source = new StringBuilder();

        source.Append(
            @"using System.Runtime.InteropServices;
using Godot.NativeInterop;
#pragma warning disable CA1707 // Disable warning: Identifiers should not contain underscores
");

        if (hasNamespace)
        {
            source.Append("namespace ");
            source.Append(classNs);
            source.Append("\n{\n");
        }

        // Outermost type first, so the nesting matches the real declaration.
        AppendPartialContainingTypeDeclarations(source, symbol.ContainingType);

        source.Append("[StructLayout(LayoutKind.Sequential)]\n");
        source.Append($"unsafe partial struct {symbol.Name}\n{{\n");

        foreach (var callback in data.Methods)
        {
            source.Append(" ");
            source.Append(callback.DeclaredAccessibility == Accessibility.Public ? "public " : "internal ");
            source.Append("delegate* unmanaged<");

            foreach (var parameter in callback.Parameters)
            {
                if (IsByRefParameter(parameter))
                {
                    if (IsGodotInteropStruct(parameter.Type) || parameter.Type.IsValueType)
                    {
                        AppendPointerType(source, parameter.Type);
                    }
                    else
                    {
                        // If it's a by-ref param and we can't get the pointer
                        // just pass it by-ref and let it be pinned.
                        AppendRefKind(source, parameter.RefKind)
                            .Append(' ')
                            .Append(parameter.Type.FullQualifiedName());
                    }
                }
                else
                {
                    source.Append(parameter.Type.FullQualifiedName());
                }

                source.Append(", ");
            }

            source.Append(callback.ReturnType.FullQualifiedName());
            source.Append($"> {callback.Name};\n");
        }

        source.Append("}\n");

        AppendContainingTypeClosingBraces(source, symbol.ContainingType);

        if (hasNamespace)
            source.Append("}\n");

        source.Append("\n#pragma warning restore CA1707\n");

        context.AddSource($"{symbol.FullQualifiedName().SanitizeQualifiedNameForUniqueHint()}.generated",
            SourceText.From(source.ToString(), Encoding.UTF8));
    }

    /// <summary>
    /// Returns the fully qualified name of <paramref name="symbol"/>'s containing namespace, or an
    /// empty string when there is none or it is the global namespace.
    /// </summary>
    private static string GetContainingNamespaceName(ISymbol symbol)
    {
        INamespaceSymbol namespaceSymbol = symbol.ContainingNamespace;
        return namespaceSymbol != null && !namespaceSymbol.IsGlobalNamespace ?
            namespaceSymbol.FullQualifiedName() :
            string.Empty;
    }

    /// <summary>
    /// Appends an opening <c>partial &lt;kind&gt; &lt;Name&gt; {</c> line for
    /// <paramref name="containingType"/> and each of its own containing types.
    /// </summary>
    /// <remarks>
    /// The declarations are emitted outermost-first: the method recurses before appending, so for
    /// <c>Outer.Inner.Target</c> the output is <c>Outer { Inner {</c>. Iterating from the innermost
    /// type would reverse the nesting whenever there are two or more enclosing types. A
    /// <c>null</c> argument appends nothing.
    /// </remarks>
    private static void AppendPartialContainingTypeDeclarations(StringBuilder source, INamedTypeSymbol? containingType)
    {
        if (containingType == null)
            return;

        AppendPartialContainingTypeDeclarations(source, containingType.ContainingType);

        source.Append("partial ");
        source.Append(containingType.GetDeclarationKeyword());
        source.Append(" ");
        source.Append(containingType.NameWithTypeParameters());
        source.Append("\n{\n");
    }

    /// <summary>
    /// Appends one closing brace for <paramref name="containingType"/> and for each of its own
    /// containing types. This matches the braces opened by
    /// <see cref="AppendPartialContainingTypeDeclarations"/>.
    /// </summary>
    private static void AppendContainingTypeClosingBraces(StringBuilder source, INamedTypeSymbol? containingType)
    {
        for (var type = containingType; type != null; type = type.ContainingType)
            source.Append("}\n"); // outer class
    }

    /// <summary>Whether <paramref name="type"/>'s fully qualified name is listed in <see cref="GodotInteropStructs"/>.</summary>
    private static bool IsGodotInteropStruct(ITypeSymbol type) =>
        GodotInteropStructs.Contains(type.FullQualifiedName());

    /// <summary>Whether the parameter is declared <c>in</c>, <c>out</c> or <c>ref</c>.</summary>
    private static bool IsByRefParameter(IParameterSymbol parameter) =>
        parameter.RefKind is RefKind.In or RefKind.Out or RefKind.Ref;

    /// <summary>
    /// Appends the C# keyword for <paramref name="refKind"/> (<c>in</c>/<c>out</c>/<c>ref</c>),
    /// or nothing for any other kind. Returns <paramref name="source"/> for chaining.
    /// </summary>
    private static StringBuilder AppendRefKind(StringBuilder source, RefKind refKind) =>
        refKind switch
        {
            RefKind.In => source.Append("in"),
            RefKind.Out => source.Append("out"),
            RefKind.Ref => source.Append("ref"),
            _ => source,
        };

    /// <summary>Appends the fully qualified name of <paramref name="type"/> followed by <c>*</c>.</summary>
    private static void AppendPointerType(StringBuilder source, ITypeSymbol type)
    {
        source.Append(type.FullQualifiedName());
        source.Append('*');
    }

    /// <summary>
    /// Appends a statement declaring <c>&lt;name&gt;_ptr</c> and initializing it with a pointer to the
    /// parameter. <c>in</c> parameters use <c>CustomUnsafe.ReadOnlyRefAsPointer</c>; all others use
    /// <c>CustomUnsafe.AsPointer</c>. <paramref name="varName"/> receives the pointer local's name.
    /// </summary>
    private static void AppendCustomUnsafeAsPointer(StringBuilder source, IParameterSymbol parameter,
        out string varName)
    {
        varName = $"{parameter.Name}_ptr";

        AppendPointerType(source, parameter.Type);
        source.Append(' ');
        source.Append(varName);
        source.Append(" = ");

        source.Append('(');
        AppendPointerType(source, parameter.Type);
        source.Append(')');

        if (parameter.RefKind == RefKind.In)
            source.Append("CustomUnsafe.ReadOnlyRefAsPointer(in ");
        else
            source.Append("CustomUnsafe.AsPointer(ref ");

        source.Append(parameter.Name);
        source.Append(");\n");
    }

    /// <summary>
    /// Appends a declaration of a local <c>&lt;name&gt;_copy</c> of the parameter's type. For <c>in</c>
    /// and <c>ref</c> parameters the local is initialized from the parameter; for <c>out</c> it is left
    /// unassigned. <paramref name="varName"/> receives the local's name.
    /// </summary>
    private static void AppendCopyToStackAndGetPointer(StringBuilder source, IParameterSymbol parameter,
        out string varName)
    {
        varName = $"{parameter.Name}_copy";

        source.Append(parameter.Type.FullQualifiedName());
        source.Append(' ');
        source.Append(varName);

        if (parameter.RefKind is RefKind.In or RefKind.Ref)
        {
            source.Append(" = ");
            source.Append(parameter.Name);
        }

        source.Append(";\n");
    }

    /// <summary>
    /// Fully qualified names (without <c>global::</c>) of the Godot interop structs whose by-ref
    /// parameters are passed as pointers obtained through <c>CustomUnsafe</c>.
    /// </summary>
    private static readonly string[] GodotInteropStructs =
    {
        "Godot.NativeInterop.godot_ref",
        "Godot.NativeInterop.godot_variant_call_error",
        "Godot.NativeInterop.godot_variant",
        "Godot.NativeInterop.godot_string",
        "Godot.NativeInterop.godot_string_name",
        "Godot.NativeInterop.godot_node_path",
        "Godot.NativeInterop.godot_signal",
        "Godot.NativeInterop.godot_callable",
        "Godot.NativeInterop.godot_array",
        "Godot.NativeInterop.godot_dictionary",
        "Godot.NativeInterop.godot_packed_byte_array",
        "Godot.NativeInterop.godot_packed_int32_array",
        "Godot.NativeInterop.godot_packed_int64_array",
        "Godot.NativeInterop.godot_packed_float32_array",
        "Godot.NativeInterop.godot_packed_float64_array",
        "Godot.NativeInterop.godot_packed_string_array",
        "Godot.NativeInterop.godot_packed_vector2_array",
        "Godot.NativeInterop.godot_packed_vector3_array",
        "Godot.NativeInterop.godot_packed_color_array",
    };
}