// ApiGenerator.cs
  using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace PixiEditor.WasmApi.Gen;

/// <summary>
/// Incremental source generator that wires every method marked with
/// <c>[PixiEditor.Extensions.WasmRuntime.ApiFunction]</c> into the WASM linker.
/// </summary>
/// <remarks>
/// For each annotated method it emits
/// <c>Linker.DefineFunction("env", name, (wasm params) => { ... });</c>. The lambda:
/// <list type="bullet">
/// <item>unpacks non-value-passable arguments through <c>WasmMemoryUtility.Get{TypeName}</c>,</item>
/// <item>replays the method body (non-return statements go through <c>MethodBodyRewriter</c>),</item>
/// <item>converts the returned value.</item>
/// </list>
/// All statements are placed in <c>partial void LinkApiFunctions()</c> of
/// <c>public partial class PixiEditor.Extensions.WasmRuntime.WasmExtensionInstance</c>.
/// </remarks>
[Generator(LanguageNames.CSharp)]
public class ApiGenerator : IIncrementalGenerator
{
    private const string FullyQualifiedApiFunctionAttributeName =
        "PixiEditor.Extensions.WasmRuntime.ApiFunctionAttribute";

    private const string ApiFunctionAttributeName = "ApiFunctionAttribute";

    /// <summary>
    /// Registers the generator pipeline: one linking statement per <c>[ApiFunction]</c> method,
    /// collected into a single generated source file.
    /// </summary>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // The per-method transform returns the generated statement as plain text rather than
        // symbols/semantic models. Strings are value-equatable, so the source output step is
        // skipped when the generated code has not changed. It also avoids keeping
        // compilation-bound objects alive in the pipeline cache.
        IncrementalValueProvider<ImmutableArray<string>> linkingStatements = context.SyntaxProvider
            .ForAttributeWithMetadataName(
                FullyQualifiedApiFunctionAttributeName,
                static (_, _) => true,
                GenerateLinkingCodeOrNull)
            .Where(static code => code is not null)
            .Select(static (code, _) => code!)
            .Collect();

        context.RegisterSourceOutput(linkingStatements, GenerateLinkerCode);
    }

    /// <summary>
    /// Emits the <c>LinkApiFunctions</c> partial method containing every linking statement.
    /// The file is added under the hint name <c>WasmExtensionInstance+ApiFunctions</c>.
    /// </summary>
    /// <param name="ctx">Source production context used to add the generated file.</param>
    /// <param name="linkingStatements">Source text of each <c>Linker.DefineFunction(...)</c> statement.</param>
    private static void GenerateLinkerCode(SourceProductionContext ctx, ImmutableArray<string> linkingStatements)
    {
        List<StatementSyntax> linkingMethodsCode = new List<StatementSyntax>(linkingStatements.Length);
        foreach (string statement in linkingStatements)
        {
            linkingMethodsCode.Add(SyntaxFactory.ParseStatement(statement));
        }

        // partial void LinkApiFunctions()
        MethodDeclarationSyntax methodDeclaration = SyntaxFactory
            .MethodDeclaration(SyntaxFactory.ParseTypeName("void"), "LinkApiFunctions")
            .AddModifiers(SyntaxFactory.Token(SyntaxKind.PartialKeyword))
            .WithBody(SyntaxFactory.Block(linkingMethodsCode));

        // public partial class WasmExtensionInstance
        ClassDeclarationSyntax classDeclaration = SyntaxFactory
            .ClassDeclaration("WasmExtensionInstance")
            .AddModifiers(SyntaxFactory.Token(SyntaxKind.PublicKeyword), SyntaxFactory.Token(SyntaxKind.PartialKeyword))
            .AddMembers(methodDeclaration);

        // namespace PixiEditor.Extensions.WasmRuntime { ... }
        NamespaceDeclarationSyntax namespaceDeclaration = SyntaxFactory
            .NamespaceDeclaration(SyntaxFactory.ParseName("PixiEditor.Extensions.WasmRuntime"))
            .AddMembers(classDeclaration);

        ctx.AddSource("WasmExtensionInstance+ApiFunctions", namespaceDeclaration.NormalizeWhitespace().ToFullString());
    }

    /// <summary>
    /// Pipeline transform that builds the linking statement text for one annotated method.
    /// </summary>
    /// <returns>
    /// The statement text, or <c>null</c> when the target is not a method or the matched
    /// attribute has no constructor argument to use as the exported name.
    /// </returns>
    private static string? GenerateLinkingCodeOrNull(GeneratorAttributeSyntaxContext context,
        CancellationToken cancelToken)
    {
        if (context.TargetSymbol is not IMethodSymbol methodSymbol)
            return null;

        // context.Attributes contains only the attributes that matched the ApiFunction metadata
        // name. Indexing the method's full attribute list instead could pick an unrelated
        // attribute declared before [ApiFunction].
        if (context.Attributes.IsEmpty || context.Attributes[0].ConstructorArguments.IsEmpty)
            return null;

        cancelToken.ThrowIfCancellationRequested();

        string functionName = context.Attributes[0].ConstructorArguments[0].ToCSharpString();
        return GenerateLinkingCodeForMethod(functionName, methodSymbol, context.SemanticModel);
    }

    /// <summary>
    /// Builds the source text of
    /// <c>Linker.DefineFunction("env", functionName, (wasm params) => { body });</c>.
    /// </summary>
    /// <param name="functionName">
    /// Exported name, already rendered as a C# literal by <c>TypedConstant.ToCSharpString()</c>.
    /// </param>
    /// <param name="methodSymbol">The <c>[ApiFunction]</c>-annotated method.</param>
    /// <param name="semanticModel">Semantic model handed to <c>MethodBodyRewriter</c>.</param>
    private static string GenerateLinkingCodeForMethod(string functionName, IMethodSymbol methodSymbol,
        SemanticModel semanticModel)
    {
        ImmutableArray<IParameterSymbol> arguments = methodSymbol.Parameters;

        // One C# parameter may map to several WASM-side parameters. These are presumably the
        // `{name}Pointer` / `{name}Length` pairs consumed by BuildVariableStatements.
        List<string> convertedParams = new List<string>();
        foreach (IParameterSymbol argSymbol in arguments)
        {
            convertedParams.AddRange(TypeConversionTable.ConvertTypeToFunctionParams(argSymbol));
        }

        // ParseParameterList parses a parenthesized list, so the parentheses must be supplied.
        ParameterListSyntax parameters =
            SyntaxFactory.ParseParameterList($"({string.Join(",", convertedParams)})");

        SyntaxList<StatementSyntax> statements = BuildVariableStatements(arguments)
            .AddRange(BuildFunctionBody(methodSymbol, semanticModel));
        BlockSyntax body = SyntaxFactory.Block(statements);

        return $"Linker.DefineFunction(\"env\", {functionName}, {parameters.ToFullString()} => \n{body.ToFullString()});";
    }

    /// <summary>
    /// Emits a local variable for each parameter whose type is not passed by value across the
    /// WASM boundary. The variable is read back through <c>WasmMemoryUtility.Get{TypeName}</c>.
    /// </summary>
    private static SyntaxList<StatementSyntax> BuildVariableStatements(ImmutableArray<IParameterSymbol> arguments)
    {
        SyntaxList<StatementSyntax> syntaxes = new SyntaxList<StatementSyntax>();
        foreach (IParameterSymbol argSymbol in arguments)
        {
            // Value-passable types (e.g. int, double) arrive as-is rather than as a pointer.
            if (TypeConversionTable.IsValuePassableType(argSymbol.Type, out _))
                continue;

            string typeName = argSymbol.Type.Name;
            string paramsString = TypeConversionTable.IsLengthType(argSymbol)
                ? $"{argSymbol.Name}Pointer, {argSymbol.Name}Length"
                : $"{argSymbol.Name}Pointer";

            syntaxes = syntaxes.Add(SyntaxFactory.ParseStatement(
                $"{argSymbol.Type.ToDisplayString()} {argSymbol.Name} = WasmMemoryUtility.Get{typeName}({paramsString});"));
        }

        return syntaxes;
    }

    /// <summary>
    /// Copies the annotated method's body into the lambda. Non-return statements go through
    /// <c>MethodBodyRewriter</c>. Top-level return statements are converted by
    /// <see cref="BuildReturnStatement"/>. Expression-bodied methods are also supported.
    /// </summary>
    /// <remarks>
    /// Only top-level <c>return</c> statements get the return-value conversion. A <c>return</c>
    /// nested inside another statement is left to the rewriter.
    /// NOTE(review): confirm MethodBodyRewriter handles nested returns.
    /// </remarks>
    private static SyntaxList<StatementSyntax> BuildFunctionBody(IMethodSymbol methodSymbol, SemanticModel semanticModel)
    {
        SyntaxList<StatementSyntax> syntaxes = new SyntaxList<StatementSyntax>();
        MethodBodyRewriter rewriter = new MethodBodyRewriter(semanticModel);

        foreach (SyntaxReference reference in methodSymbol.DeclaringSyntaxReferences)
        {
            if (reference.GetSyntax() is not MethodDeclarationSyntax methodDeclaration)
                continue;

            if (methodDeclaration.Body is { } block)
            {
                foreach (StatementSyntax statement in block.Statements)
                {
                    StatementSyntax converted = statement is ReturnStatementSyntax returnStatement
                        ? BuildReturnStatement(methodSymbol.ReturnType, returnStatement.Expression)
                        : (StatementSyntax)rewriter.Visit(statement);
                    syntaxes = syntaxes.Add(converted);
                }
            }
            else if (methodDeclaration.ExpressionBody is { } arrow)
            {
                // `=> expr` acts like `expr;` in a void method and like `return expr;` otherwise.
                // The original expression node is visited (not a synthesized statement) so the
                // rewriter's semantic-model queries stay within the model's syntax tree.
                StatementSyntax converted = methodSymbol.ReturnsVoid
                    ? SyntaxFactory.ExpressionStatement((ExpressionSyntax)rewriter.Visit(arrow.Expression))
                    : BuildReturnStatement(methodSymbol.ReturnType, arrow.Expression);
                syntaxes = syntaxes.Add(converted);
            }
        }

        return syntaxes;
    }

    /// <summary>
    /// Builds the lambda's return statement. Value-passable types are returned directly. Other
    /// types are wrapped in <c>WasmMemoryUtility.Write{TypeName}(...)</c>. A bare
    /// <c>return;</c> (no expression) is emitted unchanged.
    /// </summary>
    private static StatementSyntax BuildReturnStatement(ITypeSymbol returnType, ExpressionSyntax? expression)
    {
        if (expression is null)
            return SyntaxFactory.ParseStatement("return;");

        string value = expression.ToFullString();
        string statementString = TypeConversionTable.IsValuePassableType(returnType, out _)
            ? $"return {value};"
            : $"return WasmMemoryUtility.Write{returnType.Name}({value});";
        return SyntaxFactory.ParseStatement(statementString);
    }
}