Browse Source

Add: type validation

AnnulusGames 1 year ago
parent
commit
2740013690

+ 8 - 0
src/Lua.SourceGenerator/DiagnosticDescriptors.cs

@@ -37,4 +37,12 @@ public static class DiagnosticDescriptors
         category: Category,
         defaultSeverity: DiagnosticSeverity.Error,
         isEnabledByDefault: true);
+
+    public static readonly DiagnosticDescriptor InvalidMethodType = new(
+        id: "LUAC005",
+        title: "The arguments and return types must be LuaValue or types that can be converted to LuaValue.",
+        messageFormat: "The arguments and return types must be LuaValue or types that can be converted to LuaValue.",
+        category: Category,
+        defaultSeverity: DiagnosticSeverity.Error,
+        isEnabledByDefault: true);
 }

+ 53 - 1
src/Lua.SourceGenerator/LuaObjectGenerator.Emit.cs

@@ -1,10 +1,11 @@
 using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
 
 namespace Lua.SourceGenerator;
 
 partial class LuaObjectGenerator
 {
-    static bool TryEmit(TypeMetadata typeMetadata, CodeBuilder builder, in SourceProductionContext context)
+    static bool TryEmit(TypeMetadata typeMetadata, CodeBuilder builder, SymbolReferences references, Compilation compilation, in SourceProductionContext context)
     {
         try
         {
@@ -40,6 +41,11 @@ partial class LuaObjectGenerator
                 error = true;
             }
 
+            if (!ValidateMembers(typeMetadata, compilation, references, context))
+            {
+                error = true;
+            }
+
             if (error)
             {
                 return false;
@@ -133,6 +139,52 @@ $$"""
         }
     }
 
+    static bool ValidateMembers(TypeMetadata typeMetadata, Compilation compilation, SymbolReferences references, in SourceProductionContext context)
+    {
+        var error = true;
+
+        foreach (var property in typeMetadata.Properties)
+        {
+            if (SymbolEqualityComparer.Default.Equals(property.Type, references.LuaValue)) continue;
+            if (SymbolEqualityComparer.Default.Equals(property.Type, typeMetadata.Symbol)) continue;
+
+            var conversion = compilation.ClassifyConversion(property.Type, references.LuaValue);
+            if (!conversion.Exists)
+            {
+                context.ReportDiagnostic(Diagnostic.Create(
+                    DiagnosticDescriptors.InvalidPropertyType,
+                    property.Symbol.Locations.FirstOrDefault(),
+                    property.Type.Name));
+
+                error = false;
+            }
+        }
+
+        foreach (var method in typeMetadata.Methods)
+        {
+            foreach (var typeSymbol in method.Symbol.Parameters
+                .Select(x => x.Type)
+                .Append(method.Symbol.ReturnType))
+            {
+                if (SymbolEqualityComparer.Default.Equals(typeSymbol, references.LuaValue)) continue;
+                if (SymbolEqualityComparer.Default.Equals(typeSymbol, typeMetadata.Symbol)) continue;
+
+                var conversion = compilation.ClassifyConversion(typeSymbol, references.LuaValue);
+                if (!conversion.Exists)
+                {
+                    context.ReportDiagnostic(Diagnostic.Create(
+                        DiagnosticDescriptors.InvalidMethodType,
+                        typeSymbol.Locations.FirstOrDefault(),
+                        typeSymbol.Name));
+
+                    error = false;
+                }
+            }
+        }
+
+        return error;
+    }
+
     static bool TryEmitIndexMetamethod(TypeMetadata typeMetadata, CodeBuilder builder, in SourceProductionContext context)
     {
         builder.AppendLine("static readonly global::Lua.LuaFunction __metamethod_index = new global::Lua.LuaFunction((context, buffer, ct) =>");

+ 1 - 1
src/Lua.SourceGenerator/LuaObjectGenerator.cs

@@ -36,7 +36,7 @@ public partial class LuaObjectGenerator : IIncrementalGenerator
                 {
                     var typeMeta = new TypeMetadata((TypeDeclarationSyntax)x.TargetNode, (INamedTypeSymbol)x.TargetSymbol, references);
 
-                    if (TryEmit(typeMeta, builder, in sourceProductionContext))
+                    if (TryEmit(typeMeta, builder, references, compilation, in sourceProductionContext))
                     {
                         var fullType = typeMeta.FullTypeName
                             .Replace("global::", "")

+ 4 - 0
src/Lua.SourceGenerator/PropertyMetadata.cs

@@ -5,6 +5,7 @@ namespace Lua.SourceGenerator;
 public class PropertyMetadata
 {
     public ISymbol Symbol { get; }
+    public ITypeSymbol Type { get; }
     public string TypeFullName { get; }
     public bool IsStatic { get; }
     public bool IsReadOnly { get; }
@@ -18,16 +19,19 @@ public class PropertyMetadata
 
         if (symbol is IFieldSymbol field)
         {
+            Type = field.Type;
             TypeFullName = field.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
             IsReadOnly = field.IsReadOnly;
         }
         else if (symbol is IPropertySymbol property)
         {
+            Type = property.Type;
             TypeFullName = property.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
             IsReadOnly = property.SetMethod == null;
         }
         else
         {
+            Type = default!;
             TypeFullName = "";
         }
 

+ 2 - 0
src/Lua.SourceGenerator/SymbolReferences.cs

@@ -14,10 +14,12 @@ public sealed class SymbolReferences
             LuaObjectAttribute = luaObjectAttribute,
             LuaMemberAttribute = compilation.GetTypeByMetadataName("Lua.LuaMemberAttribute")!,
             LuaIgnoreMemberAttribute = compilation.GetTypeByMetadataName("Lua.LuaIgnoreMemberAttribute")!,
+            LuaValue = compilation.GetTypeByMetadataName("Lua.LuaValue")!,
         };
     }
 
     public INamedTypeSymbol LuaObjectAttribute { get; private set; } = default!;
     public INamedTypeSymbol LuaMemberAttribute { get; private set; } = default!;
     public INamedTypeSymbol LuaIgnoreMemberAttribute { get; private set; } = default!;
+    public INamedTypeSymbol LuaValue { get; private set; } = default!; 
 }