Browse Source

fix: fix LuaObjectGenerator value conversion handling

Akeit0 7 months ago
parent
commit
f3d299c3e2

+ 51 - 20
src/Lua.SourceGenerator/LuaObjectGenerator.Emit.cs

@@ -6,6 +6,13 @@ namespace Lua.SourceGenerator;
 
 partial class LuaObjectGenerator
 {
+    static string GetLuaValuePrefix(ITypeSymbol typeSymbol, SymbolReferences references,Compilation compilation)
+    {
+        return compilation.ClassifyCommonConversion(typeSymbol, references.LuaUserData).Exists 
+            ? "global::Lua.LuaValue.FromUserData(" 
+            : "(";
+    }
+
     static bool TryEmit(TypeMetadata typeMetadata, CodeBuilder builder, SymbolReferences references, Compilation compilation, in SourceProductionContext context, Dictionary<INamedTypeSymbol, TypeMetadata> metaDict)
     {
         try
@@ -82,17 +89,17 @@ partial class LuaObjectGenerator
 
             var metamethodSet = new HashSet<LuaObjectMetamethod>();
 
-            if (!TryEmitMethods(typeMetadata, builder, references, metamethodSet, context))
+            if (!TryEmitMethods(typeMetadata, builder, references,compilation, metamethodSet, context))
             {
                 return false;
             }
 
-            if (!TryEmitIndexMetamethod(typeMetadata, builder, context))
+            if (!TryEmitIndexMetamethod(typeMetadata, builder,references,compilation, context))
             {
                 return false;
             }
 
-            if (!TryEmitNewIndexMetamethod(typeMetadata, builder, context))
+            if (!TryEmitNewIndexMetamethod(typeMetadata, builder,references, context))
             {
                 return false;
             }
@@ -106,7 +113,7 @@ partial class LuaObjectGenerator
             builder.AppendLine($"public static implicit operator global::Lua.LuaValue({typeMetadata.FullTypeName} value)");
             using (builder.BeginBlockScope())
             {
-                builder.AppendLine("return new(value);");
+                builder.AppendLine("return  global::Lua.LuaValue.FromUserData(value);");
             }
 
             if (!ns.IsGlobalNamespace) builder.EndBlock();
@@ -133,7 +140,9 @@ partial class LuaObjectGenerator
         foreach (var property in typeMetadata.Properties)
         {
             if (SymbolEqualityComparer.Default.Equals(property.Type, references.LuaValue)) continue;
+            if (SymbolEqualityComparer.Default.Equals(property.Type, references.LuaUserData)) continue;
             if (SymbolEqualityComparer.Default.Equals(property.Type, typeMetadata.Symbol)) continue;
+            if(compilation.ClassifyConversion(property.Type, references.LuaUserData).Exists)continue;
 
             var conversion = compilation.ClassifyConversion(property.Type, references.LuaValue);
             if (!conversion.Exists && (property.Type is not INamedTypeSymbol namedTypeSymbol || !metaDict.ContainsKey(namedTypeSymbol)))
@@ -162,7 +171,9 @@ partial class LuaObjectGenerator
                 }
 
                 if (SymbolEqualityComparer.Default.Equals(typeSymbol, references.LuaValue)) goto PARAMETERS;
+                if (SymbolEqualityComparer.Default.Equals(typeSymbol, references.LuaUserData)) goto PARAMETERS;
                 if (SymbolEqualityComparer.Default.Equals(typeSymbol, typeMetadata.Symbol)) goto PARAMETERS;
+                if(compilation.ClassifyConversion(typeSymbol, references.LuaUserData).Exists) goto PARAMETERS;
 
                 var conversion = compilation.ClassifyConversion(typeSymbol, references.LuaValue);
                 if (!conversion.Exists && (typeSymbol is not INamedTypeSymbol namedTypeSymbol || !metaDict.ContainsKey(namedTypeSymbol)))
@@ -186,7 +197,9 @@ partial class LuaObjectGenerator
                     continue;
                 }
                 if (SymbolEqualityComparer.Default.Equals(typeSymbol, references.LuaValue)) continue;
+                if (SymbolEqualityComparer.Default.Equals(typeSymbol, references.LuaUserData)) continue;
                 if (SymbolEqualityComparer.Default.Equals(typeSymbol, typeMetadata.Symbol)) continue;
+                if(compilation.ClassifyConversion(typeSymbol, references.LuaUserData).Exists) continue;
 
                 var conversion = compilation.ClassifyConversion(typeSymbol, references.LuaValue);
                 if (!conversion.Exists && (typeSymbol is not INamedTypeSymbol namedTypeSymbol || !metaDict.ContainsKey(namedTypeSymbol)))
@@ -204,7 +217,7 @@ partial class LuaObjectGenerator
         return isValid;
     }
 
-    static bool TryEmitIndexMetamethod(TypeMetadata typeMetadata, CodeBuilder builder, in SourceProductionContext context)
+    static bool TryEmitIndexMetamethod(TypeMetadata typeMetadata, CodeBuilder builder, SymbolReferences references, Compilation compilation,in SourceProductionContext context)
     {
         builder.AppendLine(@"static readonly global::Lua.LuaFunction __metamethod_index = new global::Lua.LuaFunction(""index"", (context, ct) =>");
 
@@ -218,13 +231,14 @@ partial class LuaObjectGenerator
             {
                 foreach (var propertyMetadata in typeMetadata.Properties)
                 {
+                    var conversionPrefix =GetLuaValuePrefix(propertyMetadata.Type,references,compilation);
                     if (propertyMetadata.IsStatic)
                     {
-                        builder.AppendLine(@$"""{propertyMetadata.LuaMemberName}"" => new global::Lua.LuaValue({typeMetadata.FullTypeName}.{propertyMetadata.Symbol.Name}),");
+                        builder.AppendLine(@$"""{propertyMetadata.LuaMemberName}"" => {conversionPrefix}{typeMetadata.FullTypeName}.{propertyMetadata.Symbol.Name}),");
                     }
                     else
                     {
-                        builder.AppendLine(@$"""{propertyMetadata.LuaMemberName}"" => new global::Lua.LuaValue(userData.{propertyMetadata.Symbol.Name}),");
+                        builder.AppendLine(@$"""{propertyMetadata.LuaMemberName}"" => {conversionPrefix}userData.{propertyMetadata.Symbol.Name}),");
                     }
                 }
 
@@ -247,7 +261,7 @@ partial class LuaObjectGenerator
         return true;
     }
 
-    static bool TryEmitNewIndexMetamethod(TypeMetadata typeMetadata, CodeBuilder builder, in SourceProductionContext context)
+    static bool TryEmitNewIndexMetamethod(TypeMetadata typeMetadata, CodeBuilder builder,SymbolReferences references,  in SourceProductionContext context)
     {
         builder.AppendLine(@"static readonly global::Lua.LuaFunction __metamethod_newindex = new global::Lua.LuaFunction(""newindex"", (context, ct) =>");
 
@@ -271,12 +285,27 @@ partial class LuaObjectGenerator
                         }
                         else if (propertyMetadata.IsStatic)
                         {
-                            builder.AppendLine(@$"{typeMetadata.FullTypeName}.{propertyMetadata.Symbol.Name} = context.GetArgument<{propertyMetadata.TypeFullName}>(2);");
+                            if (SymbolEqualityComparer.Default.Equals(propertyMetadata.Type, references.LuaValue))
+                            {
+                                builder.AppendLine($"{typeMetadata.FullTypeName}.{propertyMetadata.Symbol.Name} = context.GetArgument(2);");
+                            }
+                            else
+                            {
+                                builder.AppendLine($"{typeMetadata.FullTypeName}.{propertyMetadata.Symbol.Name} = context.GetArgument<{propertyMetadata.TypeFullName}>(2);");
+                            }
+
                             builder.AppendLine("break;");
                         }
                         else
                         {
-                            builder.AppendLine(@$"userData.{propertyMetadata.Symbol.Name} = context.GetArgument<{propertyMetadata.TypeFullName}>(2);");
+                            if (SymbolEqualityComparer.Default.Equals(propertyMetadata.Type, references.LuaValue))
+                            {
+                                builder.AppendLine($"userData.{propertyMetadata.Symbol.Name} = context.GetArgument(2);");
+                            }
+                            else
+                            {
+                                builder.AppendLine($"userData.{propertyMetadata.Symbol.Name} = context.GetArgument<{propertyMetadata.TypeFullName}>(2);");
+                            }
                             builder.AppendLine("break;");
                         }
                     }
@@ -309,7 +338,7 @@ partial class LuaObjectGenerator
         return true;
     }
 
-    static bool TryEmitMethods(TypeMetadata typeMetadata, CodeBuilder builder, SymbolReferences references, HashSet<LuaObjectMetamethod> metamethodSet, in SourceProductionContext context)
+    static bool TryEmitMethods(TypeMetadata typeMetadata, CodeBuilder builder, SymbolReferences references,Compilation compilation, HashSet<LuaObjectMetamethod> metamethodSet, in SourceProductionContext context)
     {
         builder.AppendLine();
 
@@ -320,7 +349,7 @@ partial class LuaObjectGenerator
             if (methodMetadata.HasMemberAttribute)
             {
                 functionName = $"__function_{methodMetadata.LuaMemberName}";
-                EmitMethodFunction(functionName, methodMetadata.LuaMemberName, typeMetadata, methodMetadata, builder, references);
+                EmitMethodFunction(functionName, methodMetadata.LuaMemberName, typeMetadata, methodMetadata, builder, references,compilation);
             }
 
             if (methodMetadata.HasMetamethodAttribute)
@@ -339,7 +368,7 @@ partial class LuaObjectGenerator
 
                 if (functionName == null)
                 {
-                    EmitMethodFunction($"__metamethod_{methodMetadata.Metamethod}", methodMetadata.Metamethod.ToString().ToLower(), typeMetadata, methodMetadata, builder, references);
+                    EmitMethodFunction($"__metamethod_{methodMetadata.Metamethod}", methodMetadata.Metamethod.ToString().ToLower(), typeMetadata, methodMetadata, builder, references,compilation);
                 }
                 else
                 {
@@ -351,7 +380,7 @@ partial class LuaObjectGenerator
         return true;
     }
 
-    static void EmitMethodFunction(string functionName, string chunkName, TypeMetadata typeMetadata, MethodMetadata methodMetadata, CodeBuilder builder, SymbolReferences references)
+    static void EmitMethodFunction(string functionName, string chunkName, TypeMetadata typeMetadata, MethodMetadata methodMetadata, CodeBuilder builder, SymbolReferences references,Compilation compilation)
     {
         builder.AppendLine($@"static readonly global::Lua.LuaFunction {functionName} = new global::Lua.LuaFunction(""{chunkName}"", {(methodMetadata.IsAsync ? "async" : "")} (context, ct) =>");
 
@@ -445,14 +474,16 @@ partial class LuaObjectGenerator
             builder.Append("return ");
             if (methodMetadata.HasReturnValue)
             {
-                if (SymbolEqualityComparer.Default.Equals(methodMetadata.Symbol.ReturnType, references.LuaValue))
+                var returnType = methodMetadata.Symbol.ReturnType;
+                if (methodMetadata.IsAsync)
                 {
-                    builder.AppendLine(methodMetadata.IsAsync ? "context.Return(result));" : "new global::System.Threading.Tasks.ValueTask<int>(context.Return(result));");
-                }
-                else
-                {
-                    builder.AppendLine(methodMetadata.IsAsync ? "context.Return(new global::Lua.LuaValue(result));" : "new global::System.Threading.Tasks.ValueTask<int>(context.Return(new global::Lua.LuaValue(result)));");
+                    var namedType = (INamedTypeSymbol)returnType;
+                    if (namedType.TypeArguments.Length == 1) returnType = namedType.TypeArguments[0];
                 }
+
+                var conversionPrefix =GetLuaValuePrefix(returnType,references,compilation);
+                builder.AppendLine(methodMetadata.IsAsync ? $"context.Return({conversionPrefix}result));" : $"new global::System.Threading.Tasks.ValueTask<int>(context.Return({conversionPrefix}result)));");
+                
             }
             else
             {

+ 6 - 1
src/Lua/LuaValue.cs

@@ -396,8 +396,13 @@ public readonly struct LuaValue : IEquatable<LuaValue>
         return true;
     }
 
+    public static LuaValue FromUserData(ILuaUserData userData)
+    {
+        return new (userData);
+    }
+
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public LuaValue(object obj)
+    internal LuaValue(object obj)
     {
         Type = LuaValueType.LightUserData;
         referenceValue = obj;

+ 7 - 6
tests/Lua.Tests/LuaObjectTests.cs

@@ -7,7 +7,8 @@ public partial class TestUserData
 {
     [LuaMember]
     public int Property { get; set; }
-
+    [LuaMember]
+    public LuaValue LuaValueProperty { get; set; }
     [LuaMember("p2")]
     public string PropertyWithName { get; set; } = "";
 
@@ -37,10 +38,10 @@ public partial class TestUserData
     }
 
     [LuaMember]
-    public async Task<double> InstanceMethodWithReturnValueAsync()
+    public async ValueTask<LuaValue> InstanceMethodWithReturnValueAsync(LuaValue value,CancellationToken ct)
     {
-        await Task.Delay(1);
-        return Property;
+        await Task.Delay(1,ct);
+        return value;
     }
 
     [LuaMetamethod(LuaObjectMetamethod.Call)]
@@ -135,10 +136,10 @@ public class LuaObjectTests
 
         var state = LuaState.Create();
         state.Environment["test"] = userData;
-        var results = await state.DoStringAsync("return test:InstanceMethodWithReturnValueAsync()");
+        var results = await state.DoStringAsync("return test:InstanceMethodWithReturnValueAsync(2)");
 
         Assert.That(results, Has.Length.EqualTo(1));
-        Assert.That(results[0], Is.EqualTo(new LuaValue(1)));
+        Assert.That(results[0], Is.EqualTo(new LuaValue(2)));
     }
 
     [Test]