Browse Source

Merge pull request #513 from PixiEditor/code-gen-fix

Fixed code generation
Krzysztof Krysiński 2 years ago
parent
commit
dd3fc5e1e6
1 changed files with 112 additions and 122 deletions
  1. 112 122
      src/PixiEditorGen/CommandNameListGenerator.cs

+ 112 - 122
src/PixiEditorGen/CommandNameListGenerator.cs

@@ -1,5 +1,6 @@
-using System.Text;
+using System.Collections.Immutable;
 using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
 
 
 namespace PixiEditorGen;
 namespace PixiEditorGen;
@@ -7,171 +8,131 @@ namespace PixiEditorGen;
 [Generator(LanguageNames.CSharp)]
 [Generator(LanguageNames.CSharp)]
 public class CommandNameListGenerator : IIncrementalGenerator
 public class CommandNameListGenerator : IIncrementalGenerator
 {
 {
-    private const string Command = "PixiEditor.Models.Commands.Attributes.Commands";
+    private const string Commands = "PixiEditor.Models.Commands.Attributes.Commands";
 
 
     private const string Evaluators = "PixiEditor.Models.Commands.Attributes.Evaluators.Evaluator";
     private const string Evaluators = "PixiEditor.Models.Commands.Attributes.Evaluators.Evaluator";
 
 
     private const string Groups = "PixiEditor.Models.Commands.Attributes.Commands.Command.GroupAttribute";
     private const string Groups = "PixiEditor.Models.Commands.Attributes.Commands.Command.GroupAttribute";
-    
+
     public void Initialize(IncrementalGeneratorInitializationContext context)
     public void Initialize(IncrementalGeneratorInitializationContext context)
     {
     {
-        var commandList = context.SyntaxProvider.CreateSyntaxProvider(
-            (x, token) =>
-        {
-            return x is MethodDeclarationSyntax method && method.AttributeLists.Count > 0;
-        }, static (context, cancelToken) =>
-        {
-            var method = (MethodDeclarationSyntax)context.Node;
+        var commandList = CreateSyntaxProvider<Command>(context, Commands).Where(x => x != null);
+        var evaluatorList = CreateSyntaxProvider<Command>(context, Evaluators).Where(x => x != null);
+        var groupList = CreateSyntaxProvider<Group>(context, Groups).Where(x => x != null);
 
 
-            if (!HasCommandAttribute(method, context, cancelToken, Command))
-                return (null, null, null);
-
-            var symbol = context.SemanticModel.GetDeclaredSymbol(method, cancelToken);
-
-            if (symbol is IMethodSymbol methodSymbol)
-            {
-                if (methodSymbol.ReceiverType == null)
-                    return (null, null, null);
-                
-                return (methodSymbol.ReceiverType.ToDisplayString(), methodSymbol.Name, methodSymbol.Parameters.Select(x => x.ToDisplayString()));
-            }
-            else
-            {
-                return (null, null, null);
-            }
-        }).Where(x => x.Item1 != null);
+        context.RegisterSourceOutput(commandList.Collect(), (context, commands) => AddSource(context, commands, "Commands"));
+        context.RegisterSourceOutput(evaluatorList.Collect(), (context, evaluators) => AddSource(context, evaluators, "Evaluators"));
+        context.RegisterSourceOutput(groupList.Collect(), AddGroupsSource);
+    }
 
 
-        var evaluatorList = context.SyntaxProvider.CreateSyntaxProvider(
+    private IncrementalValuesProvider<T?> CreateSyntaxProvider<T>(IncrementalGeneratorInitializationContext context, string className) where T : CommandMember<T>
+    {
+        return context.SyntaxProvider.CreateSyntaxProvider(
             (x, token) =>
             (x, token) =>
             {
             {
-                return x is MethodDeclarationSyntax method && method.AttributeLists.Count > 0;
-            }, static (context, cancelToken) =>
-            {
-                var method = (MethodDeclarationSyntax)context.Node;
-
-                if (!HasCommandAttribute(method, context, cancelToken, Evaluators))
-                    return (null, null, null);
-
-                var symbol = context.SemanticModel.GetDeclaredSymbol(method, cancelToken);
-
-                if (symbol is IMethodSymbol methodSymbol)
+                if (typeof(T) == typeof(Command))
                 {
                 {
-                    return (methodSymbol.ReceiverType.ToDisplayString(), methodSymbol.Name, methodSymbol.Parameters.Select(x => x.ToDisplayString()));
+                    return x is MethodDeclarationSyntax method && method.AttributeLists.Count > 0;
                 }
                 }
                 else
                 else
                 {
                 {
-                    return (null, null, null);
+                    return x is TypeDeclarationSyntax type && type.AttributeLists.Count > 0;
                 }
                 }
-            }).Where(x => x.Item1 != null);
-        
-        var groupList = context.SyntaxProvider.CreateSyntaxProvider(
-            (x, token) =>
-            {
-                return x is TypeDeclarationSyntax type && type.AttributeLists.Count > 0;
-            }, static (context, cancelToken) =>
+            }, (context, cancelToken) =>
             {
             {
-                var method = (TypeDeclarationSyntax)context.Node;
+                var member = (MemberDeclarationSyntax)context.Node;
 
 
-                if (!HasCommandAttribute(method, context, cancelToken, Groups))
+                if (!HasCommandAttribute(member, context, cancelToken, className))
                     return null;
                     return null;
 
 
-                var symbol = context.SemanticModel.GetDeclaredSymbol(method, cancelToken);
+                var symbol = context.SemanticModel.GetDeclaredSymbol(member, cancelToken);
+
+                if (symbol is IMethodSymbol methodSymbol && typeof(T) == typeof(Command))
+                {
+                    if (methodSymbol.ReceiverType == null)
+                        return null;
 
 
-                if (symbol is ITypeSymbol methodSymbol)
+                    return (T)(object)new Command(methodSymbol);
+                }
+                else if (symbol is ITypeSymbol typeSymbol && typeof(T) == typeof(Group))
                 {
                 {
-                    return methodSymbol.ToDisplayString();
+                    return (T)(object)new Group(typeSymbol);
                 }
                 }
                 else
                 else
                 {
                 {
                     return null;
                     return null;
                 }
                 }
-            }).Where(x => x != null);
-
-        context.RegisterSourceOutput(commandList.Collect(), static (context, methodNames) =>
-        {
-            var code = new StringBuilder(
-                @"namespace PixiEditor.Models.Commands;
-
-internal partial class CommandNameList {
-    partial void AddCommands() {");
+            });
+    }
 
 
-            List<string> createdClasses = new List<string>();
+    private void AddSource(SourceProductionContext context, ImmutableArray<Command> methodNames, string name)
+    {
+        List<string> createdClasses = new List<string>();
+        SyntaxList<StatementSyntax> statements = new SyntaxList<StatementSyntax>();
 
 
-            foreach (var method in methodNames)
+        foreach (var methodName in methodNames)
+        {
+            if (!createdClasses.Contains(methodName.OwnerTypeName))
             {
             {
-                if (!createdClasses.Contains(method.Item1))
-                {
-                    code.AppendLine($"      Commands.Add(typeof({method.Item1}), new());");
-                    createdClasses.Add(method.Item1);
-                }
-
-                var parameters = string.Join(",", method.Item3.Select(x => $"typeof({x})"));
-                
-                bool hasParameters = parameters.Length > 0;
-                string paramString = hasParameters ? $"new Type[] {{ {parameters} }}" : "Array.Empty<Type>()";
-                
-                code.AppendLine($"      Commands[typeof({method.Item1})].Add((\"{method.Item2}\", {paramString}));");
+                statements = statements.Add(SyntaxFactory.ParseStatement($"{name}.Add(typeof({methodName.OwnerTypeName}), new());"));
+                createdClasses.Add(methodName.OwnerTypeName);
             }
             }
 
 
-            code.Append("   }\n}");
+            var parameters = string.Join(",", methodName.ParameterTypeNames);
+            string paramString = parameters.Length > 0 ? $"new Type[] {{ {parameters} }}" : "Array.Empty<Type>()";
 
 
-            context.AddSource("CommandNameList+Commands", code.ToString());
-        });
-        
-        context.RegisterSourceOutput(evaluatorList.Collect(), static (context, methodNames) =>
-        {
-            var code = new StringBuilder(
-                @"namespace PixiEditor.Models.Commands;
+            statements = statements.Add(SyntaxFactory.ParseStatement($"{name}[typeof({methodName.OwnerTypeName})].Add((\"{methodName.MethodName}\", {paramString}));"));
+        }
 
 
-internal partial class CommandNameList {
-    partial void AddEvaluators() {");
+        // partial void Add$name$()
+        var method = SyntaxFactory
+            .MethodDeclaration(SyntaxFactory.ParseTypeName("void"), $"Add{name}")
+            .AddModifiers(SyntaxFactory.Token(SyntaxKind.PartialKeyword))
+            .WithBody(SyntaxFactory.Block(statements));
 
 
-            List<string> createdClasses = new List<string>();
+        // internal partial class CommandNameList
+        var cDecl = SyntaxFactory
+            .ClassDeclaration("CommandNameList")
+            .AddModifiers(SyntaxFactory.Token(SyntaxKind.InternalKeyword), SyntaxFactory.Token(SyntaxKind.PartialKeyword))
+            .AddMembers(method);
 
 
-            foreach (var method in methodNames)
-            {
-                if (!createdClasses.Contains(method.Item1))
-                {
-                    code.AppendLine($"      Evaluators.Add(typeof({method.Item1}), new());");
-                    createdClasses.Add(method.Item1);
-                }
+        // namespace PixiEditor.Models.Commands
+        var nspace = SyntaxFactory
+            .NamespaceDeclaration(SyntaxFactory.ParseName("PixiEditor.Models.Commands"))
+            .AddMembers(cDecl);
 
 
-                if (method.Item3 == null || !method.Item3.Any())
-                {
-                    code.AppendLine($"      Evaluators[typeof({method.Item1})].Add((\"{method.Item2}\", Array.Empty<Type>()));");
-                }
-                else
-                {
-                    var parameters = string.Join(",", method.Item3.Select(x => $"typeof({x})"));
-                    string paramString = parameters.Length > 0 ? $"new Type[] {{ {parameters} }}" : "Array.Empty<Type>()";
-                    code.AppendLine($"      Evaluators[typeof({method.Item1})].Add((\"{method.Item2}\", {paramString}));");
-                }
-            }
+        context.AddSource($"CommandNameList+{name}", nspace.NormalizeWhitespace().ToFullString());
+    }
 
 
-            code.Append("   }\n}");
+    private void AddGroupsSource(SourceProductionContext context, ImmutableArray<Group> groups)
+    {
+        SyntaxList<StatementSyntax> statements = new SyntaxList<StatementSyntax>();
 
 
-            context.AddSource("CommandNameList+Evaluators", code.ToString());
-        });
-        
-        context.RegisterSourceOutput(groupList.Collect(), static (context, typeNames) =>
+        foreach (var group in groups)
         {
         {
-            var code = new StringBuilder(
-                @"namespace PixiEditor.Models.Commands;
+            statements = statements.Add(SyntaxFactory.ParseStatement($"Groups.Add(typeof({group.OwnerTypeName}));"));
+        }
 
 
-internal partial class CommandNameList {
-    partial void AddGroups() {");
+        // partial void AddGroups()
+        var method = SyntaxFactory
+            .MethodDeclaration(SyntaxFactory.ParseTypeName("void"), "AddGroups")
+            .AddModifiers(SyntaxFactory.Token(SyntaxKind.PartialKeyword))
+            .WithBody(SyntaxFactory.Block(statements));
 
 
-            foreach (var name in typeNames)
-            {
-                code.AppendLine($"      Groups.Add(typeof({name}));");
-            }
+        // internal partial class CommandNameList
+        var cDecl = SyntaxFactory
+            .ClassDeclaration("CommandNameList")
+            .AddModifiers(SyntaxFactory.Token(SyntaxKind.InternalKeyword), SyntaxFactory.Token(SyntaxKind.PartialKeyword))
+            .AddMembers(method);
 
 
-            code.Append("   }\n}");
+        // namespace PixiEditor.Models.Commands
+        var nspace = SyntaxFactory
+            .NamespaceDeclaration(SyntaxFactory.ParseName("PixiEditor.Models.Commands"))
+            .AddMembers(cDecl);
 
 
-            context.AddSource("CommandNameList+Groups", code.ToString());
-        });
+        context.AddSource("CommandNameList+Groups", nspace.NormalizeWhitespace().ToFullString());
     }
     }
-    
+
     private static bool HasCommandAttribute(MemberDeclarationSyntax declaration, GeneratorSyntaxContext context, CancellationToken token, string commandAttributeStart)
     private static bool HasCommandAttribute(MemberDeclarationSyntax declaration, GeneratorSyntaxContext context, CancellationToken token, string commandAttributeStart)
     {
     {
         foreach (var attrList in declaration.AttributeLists)
         foreach (var attrList in declaration.AttributeLists)
@@ -191,4 +152,33 @@ internal partial class CommandNameList {
 
 
         return false;
         return false;
     }
     }
+
+    class CommandMember<TSelf> where TSelf : CommandMember<TSelf>
+    {
+        public string OwnerTypeName { get; }
+
+        public CommandMember(string ownerTypeName)
+        {
+            OwnerTypeName = ownerTypeName;
+        }
+    }
+
+    class Command : CommandMember<Command>
+    {
+        public string MethodName { get; }
+
+        public string[] ParameterTypeNames { get; }
+
+        public Command(IMethodSymbol symbol) : base(symbol.ContainingType.ToDisplayString())
+        {
+            MethodName = symbol.Name;
+            ParameterTypeNames = symbol.Parameters.Select(x => $"typeof({x.Type.ToDisplayString()})").ToArray();
+        }
+    }
+
+    class Group : CommandMember<Group>
+    {
+        public Group(ITypeSymbol symbol) : base(symbol.ToDisplayString())
+        { }
+    }
 }
 }