Browse Source

Update: support metamethod

AnnulusGames 1 year ago
parent
commit
6d33fc77b7

+ 8 - 0
src/Lua.SourceGenerator/DiagnosticDescriptors.cs

@@ -45,4 +45,12 @@ public static class DiagnosticDescriptors
         category: Category,
         defaultSeverity: DiagnosticSeverity.Error,
         isEnabledByDefault: true);
+
+    public static readonly DiagnosticDescriptor DuplicateMetamethod = new(
+        id: "LUAC006",
+        title: "The type already contains same metamethod.",
+        messageFormat: "Type '{0}' already contains a '{1}' metamethod.,",
+        category: Category,
+        defaultSeverity: DiagnosticSeverity.Error,
+        isEnabledByDefault: true);
 }

+ 101 - 39
src/Lua.SourceGenerator/LuaObjectGenerator.Emit.cs

@@ -79,49 +79,35 @@ partial class LuaObjectGenerator
 
             using var _ = builder.BeginBlockScope($"partial {typeDeclarationKeyword} {typeMetadata.TypeName} : global::Lua.ILuaUserData");
 
-            // add ILuaUserData impl
-            builder.Append(
-$$"""
-        global::Lua.LuaTable? global::Lua.ILuaUserData.Metatable
-        {
-            get
-            {
-                if (__metatable != null) return __metatable;
+            var metamethodSet = new HashSet<LuaObjectMetamethod>();
 
-                __metatable = new();
-                __metatable[global::Lua.Runtime.Metamethods.Index] = __metamethod_index;
-                __metatable[global::Lua.Runtime.Metamethods.NewIndex] = __metamethod_newindex;
-                return __metatable;
-            }
-            set
+            if (!TryEmitMethods(typeMetadata, builder, metamethodSet, context))
             {
-                __metatable = value;
+                return false;
             }
-        }
-        static global::Lua.LuaTable? __metatable;
-
-        public static implicit operator global::Lua.LuaValue({{typeMetadata.FullTypeName}} value)
-        {
-            return new(value);
-        }
-
-""", false);
 
-            if (!TryEmitMethods(typeMetadata, builder, context))
+            if (!TryEmitIndexMetamethod(typeMetadata, builder, context))
             {
                 return false;
             }
 
-            if (!TryEmitIndexMetamethod(typeMetadata, builder, context))
+            if (!TryEmitNewIndexMetamethod(typeMetadata, builder, context))
             {
                 return false;
             }
 
-            if (!TryEmitNewIndexMetamethod(typeMetadata, builder, context))
+            if (!TryEmitMetatable(builder, metamethodSet, context))
             {
                 return false;
             }
 
+            // implicit operator
+            builder.AppendLine($"public static implicit operator global::Lua.LuaValue({typeMetadata.FullTypeName} value)");
+            using (builder.BeginBlockScope())
+            {
+                builder.AppendLine("return new(value);");
+            }
+
             if (!ns.IsGlobalNamespace) builder.EndBlock();
 
             builder.AppendLine("#pragma warning restore CS0162 // Unreachable code");
@@ -209,7 +195,8 @@ $$"""
                     }
                 }
 
-                foreach (var methodMetadata in typeMetadata.Methods)
+                foreach (var methodMetadata in typeMetadata.Methods
+                    .Where(x => x.HasMemberAttribute))
                 {
                     builder.AppendLine(@$"""{methodMetadata.LuaMemberName}"" => new global::Lua.LuaValue(__function_{methodMetadata.LuaMemberName}),");
                 }
@@ -262,7 +249,8 @@ $$"""
                     }
                 }
 
-                foreach (var methodMetadata in typeMetadata.Methods)
+                foreach (var methodMetadata in typeMetadata.Methods
+                    .Where(x => x.HasMemberAttribute))
                 {
                     builder.AppendLine(@$"case ""{methodMetadata.LuaMemberName}"":");
 
@@ -288,18 +276,22 @@ $$"""
         return true;
     }
 
-    static bool TryEmitMethods(TypeMetadata typeMetadata, CodeBuilder builder, in SourceProductionContext context)
+    static bool TryEmitMethods(TypeMetadata typeMetadata, CodeBuilder builder, HashSet<LuaObjectMetamethod> metamethodSet, in SourceProductionContext context)
     {
-        builder.AppendLine();
-
-        foreach (var methodMetadata in typeMetadata.Methods)
+        static void EmitMethodFunction(string functionName, TypeMetadata typeMetadata, MethodMetadata methodMetadata, CodeBuilder builder)
         {
-            builder.AppendLine($"static readonly global::Lua.LuaFunction __function_{methodMetadata.LuaMemberName} = new global::Lua.LuaFunction((context, buffer, ct) =>");
+            builder.AppendLine($"static readonly global::Lua.LuaFunction {functionName} = new global::Lua.LuaFunction((context, buffer, ct) =>");
 
             using (builder.BeginBlockScope())
             {
                 var index = 0;
 
+                if (!methodMetadata.IsStatic)
+                {
+                    builder.AppendLine($"var userData = context.GetArgument<{typeMetadata.FullTypeName}>(0);");
+                    index++;
+                }
+
                 foreach (var parameter in methodMetadata.Symbol.Parameters)
                 {
                     builder.AppendLine($"var arg{index} = context.GetArgument<{parameter.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>({index});");
@@ -309,24 +301,94 @@ $$"""
                 if (methodMetadata.IsStatic)
                 {
                     builder.Append($"var result = {typeMetadata.FullTypeName}.{methodMetadata.Symbol.Name}(");
-                    builder.Append(string.Join(",", Enumerable.Range(0, index).Select(x => $"arg{x}")));
-                    builder.AppendLine(");");
+                    builder.Append(string.Join(",", Enumerable.Range(0, index).Select(x => $"arg{x}")), false);
+                    builder.AppendLine(");", false);
                 }
                 else
                 {
 
                     builder.Append($"var result = userData.{methodMetadata.Symbol.Name}(");
-                    builder.Append(string.Join(",", Enumerable.Range(1, index).Select(x => $"arg{x}")));
-                    builder.AppendLine(");");
+                    builder.Append(string.Join(",", Enumerable.Range(1, index - 1).Select(x => $"arg{x}")), false);
+                    builder.AppendLine(");", false);
                 }
 
                 builder.AppendLine("buffer.Span[0] = new global::Lua.LuaValue(result);");
                 builder.AppendLine("return new(1);");
             }
-
             builder.AppendLine(");");
+            builder.AppendLine();
         }
 
+        builder.AppendLine();
+
+        foreach (var methodMetadata in typeMetadata.Methods)
+        {
+            string? functionName = null;
+
+            if (methodMetadata.HasMemberAttribute)
+            {
+                functionName = $"__function_{methodMetadata.LuaMemberName}";
+                EmitMethodFunction(functionName, typeMetadata, methodMetadata, builder);
+            }
+
+            if (methodMetadata.HasMetamethodAttribute)
+            {
+                if (!metamethodSet.Add(methodMetadata.Metamethod))
+                {
+                    context.ReportDiagnostic(Diagnostic.Create(
+                        DiagnosticDescriptors.DuplicateMetamethod,
+                        methodMetadata.Symbol.Locations.FirstOrDefault(),
+                        typeMetadata.TypeName,
+                        methodMetadata.Metamethod
+                    ));
+
+                    continue;
+                }
+
+                if (functionName == null)
+                {
+                    EmitMethodFunction($"__metamethod_{methodMetadata.Metamethod}", typeMetadata, methodMetadata, builder);
+                }
+                else
+                {
+                    builder.AppendLine($"static global::Lua.LuaFunction __metamethod_{methodMetadata.Metamethod} => {functionName};");
+                }
+            }
+        }
+
+        return true;
+    }
+
+    static bool TryEmitMetatable(CodeBuilder builder, IEnumerable<LuaObjectMetamethod> metamethods, in SourceProductionContext context)
+    {
+        builder.AppendLine("global::Lua.LuaTable? global::Lua.ILuaUserData.Metatable");
+        using (builder.BeginBlockScope())
+        {
+            builder.AppendLine("get");
+            using (builder.BeginBlockScope())
+            {
+                builder.AppendLine("if (__metatable != null) return __metatable;");
+                builder.AppendLine();
+                builder.AppendLine("__metatable = new();");
+                builder.AppendLine("__metatable[global::Lua.Runtime.Metamethods.Index] = __metamethod_index;");
+                builder.AppendLine("__metatable[global::Lua.Runtime.Metamethods.NewIndex] = __metamethod_newindex;");
+                foreach (var metamethod in metamethods)
+                {
+                    builder.AppendLine($"__metatable[global::Lua.Runtime.Metamethods.{metamethod}] = __metamethod_{metamethod};");
+                }
+                builder.AppendLine("return __metatable;");
+            }
+
+            builder.AppendLine("set");
+            using (builder.BeginBlockScope())
+            {
+                builder.AppendLine("__metatable = value;");
+            }
+        }
+
+        builder.AppendLine("static global::Lua.LuaTable? __metatable;");
+        builder.AppendLine();
+
         return true;
     }
 }

+ 23 - 0
src/Lua.SourceGenerator/LuaObjectMetamethod.cs

@@ -0,0 +1,23 @@
+namespace Lua.SourceGenerator;
+
+// same as Lua.LuaObjectMetamethod
+
+internal enum LuaObjectMetamethod
+{
+    Add,
+    Sub,
+    Mul,
+    Div,
+    Mod,
+    Pow,
+    Unm,
+    Len,
+    Eq,
+    Lt,
+    Le,
+    Call,
+    Concat,
+    Pairs,
+    IPairs,
+    ToString,
+}

+ 14 - 1
src/Lua.SourceGenerator/MethodMetadata.cs

@@ -2,11 +2,14 @@ using Microsoft.CodeAnalysis;
 
 namespace Lua.SourceGenerator;
 
-public class MethodMetadata
+internal class MethodMetadata
 {
     public IMethodSymbol Symbol { get; }
     public bool IsStatic { get; }
+    public bool HasMemberAttribute { get; }
+    public bool HasMetamethodAttribute { get; }
     public string LuaMemberName { get; }
+    public LuaObjectMetamethod Metamethod { get; }
 
     public MethodMetadata(IMethodSymbol symbol, SymbolReferences references)
     {
@@ -16,6 +19,8 @@ public class MethodMetadata
         LuaMemberName = symbol.Name;
 
         var memberAttribute = symbol.GetAttribute(references.LuaMemberAttribute);
+        HasMemberAttribute = memberAttribute != null;
+
         if (memberAttribute != null)
         {
             if (memberAttribute.ConstructorArguments.Length > 0)
@@ -27,5 +32,13 @@ public class MethodMetadata
                 }
             }
         }
+
+        var metamethodAttribute = symbol.GetAttribute(references.LuaMetamethodAttribute);
+        HasMetamethodAttribute = metamethodAttribute != null;
+
+        if (metamethodAttribute != null)
+        {
+            Metamethod = (LuaObjectMetamethod)Enum.Parse(typeof(LuaObjectMetamethod), metamethodAttribute.ConstructorArguments[0].Value!.ToString());
+        }
     }
 }

+ 2 - 0
src/Lua.SourceGenerator/SymbolReferences.cs

@@ -14,6 +14,7 @@ public sealed class SymbolReferences
             LuaObjectAttribute = luaObjectAttribute,
             LuaMemberAttribute = compilation.GetTypeByMetadataName("Lua.LuaMemberAttribute")!,
             LuaIgnoreMemberAttribute = compilation.GetTypeByMetadataName("Lua.LuaIgnoreMemberAttribute")!,
+            LuaMetamethodAttribute = compilation.GetTypeByMetadataName("Lua.LuaMetamethodAttribute")!,
             LuaValue = compilation.GetTypeByMetadataName("Lua.LuaValue")!,
         };
     }
@@ -21,5 +22,6 @@ public sealed class SymbolReferences
     public INamedTypeSymbol LuaObjectAttribute { get; private set; } = default!;
     public INamedTypeSymbol LuaMemberAttribute { get; private set; } = default!;
     public INamedTypeSymbol LuaIgnoreMemberAttribute { get; private set; } = default!;
+    public INamedTypeSymbol LuaMetamethodAttribute { get; private set; } = default!;
     public INamedTypeSymbol LuaValue { get; private set; } = default!; 
 }

+ 2 - 4
src/Lua.SourceGenerator/TypeMetadata.cs

@@ -4,7 +4,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax;
 
 namespace Lua.SourceGenerator;
 
-internal record class TypeMetadata
+internal class TypeMetadata
 {
     public TypeDeclarationSyntax Syntax { get; }
     public INamedTypeSymbol Symbol { get; }
@@ -44,10 +44,8 @@ internal record class TypeMetadata
             .Select(x => (IMethodSymbol)x)
             .Where(x =>
             {
-                if (!x.ContainsAttribute(references.LuaMemberAttribute)) return false;
                 if (x.ContainsAttribute(references.LuaIgnoreMemberAttribute)) return false;
-
-                return true;
+                return x.ContainsAttribute(references.LuaMemberAttribute) || x.ContainsAttribute(references.LuaMetamethodAttribute);
             })
             .Select(x => new MethodMetadata(x, references))
             .ToArray();

+ 6 - 0
src/Lua/Attributes.cs

@@ -30,6 +30,12 @@ public sealed class LuaMemberAttribute : Attribute
     public string? Name { get; }
 }
 
+[AttributeUsage(AttributeTargets.Method)]
+public sealed class LuaMetamethodAttribute(LuaObjectMetamethod metamethod) : Attribute
+{
+    public LuaObjectMetamethod Metamethod { get; } = metamethod;
+}
+
 [AttributeUsage(AttributeTargets.Method | AttributeTargets.Field | AttributeTargets.Property)]
 public sealed class LuaIgnoreMemberAttribute : Attribute
 {

+ 0 - 8
src/Lua/LuaObjectGenerateOptions.cs

@@ -1,8 +0,0 @@
-namespace Lua;
-
-[Flags]
-public enum LuaObjectGenerateOptions
-{
-    None = 0,
-    
-}

+ 21 - 0
src/Lua/LuaObjectMetamethod.cs

@@ -0,0 +1,21 @@
+namespace Lua;
+
+public enum LuaObjectMetamethod
+{
+    Add,
+    Sub,
+    Mul,
+    Div,
+    Mod,
+    Pow,
+    Unm,
+    Len,
+    Eq,
+    Lt,
+    Le,
+    Call,
+    Concat,
+    Pairs,
+    IPairs,
+    ToString,
+}