Sfoglia il codice sorgente

Merge pull request #258 from nuskey8/fix/value-clear-after-method

Fix: meta method clears other values
Akito Inoue 1 settimana fa
parent
commit
067ce2bea5

+ 33 - 19
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -215,10 +215,11 @@ public static partial class LuaVirtualMachine
             State.PopCallStackFrameUntil(BaseCallStackCount);
             State.PopCallStackFrameUntil(BaseCallStackCount);
         }
         }
 
 
-        bool ExecutePostOperation(PostOperationType postOperation)
+        bool ExecutePostOperation(int varArgs, PostOperationType postOperation)
         {
         {
             var stackCount = Stack.Count;
             var stackCount = Stack.Count;
             var resultsSpan = Stack.GetBuffer()[CurrentReturnFrameBase..];
             var resultsSpan = Stack.GetBuffer()[CurrentReturnFrameBase..];
+            var lastTop = CurrentReturnFrameBase - varArgs;
             switch (postOperation)
             switch (postOperation)
             {
             {
                 case PostOperationType.Nop: break;
                 case PostOperationType.Nop: break;
@@ -226,7 +227,7 @@ public static partial class LuaVirtualMachine
                     var RA = Instruction.A + FrameBase;
                     var RA = Instruction.A + FrameBase;
                     Stack.Get(RA) = stackCount > CurrentReturnFrameBase ? Stack.Get(CurrentReturnFrameBase) : LuaValue.Nil;
                     Stack.Get(RA) = stackCount > CurrentReturnFrameBase ? Stack.Get(CurrentReturnFrameBase) : LuaValue.Nil;
                     Stack.NotifyTop(RA + 1);
                     Stack.NotifyTop(RA + 1);
-                    Stack.PopUntil(RA + 1);
+                    Stack.PopUntil(Math.Max(RA + 1, lastTop));
                     break;
                     break;
                 case PostOperationType.TForCall:
                 case PostOperationType.TForCall:
                     TForCallPostOperation(this);
                     TForCallPostOperation(this);
@@ -242,10 +243,10 @@ public static partial class LuaVirtualMachine
 
 
                     break;
                     break;
                 case PostOperationType.Self:
                 case PostOperationType.Self:
-                    SelfPostOperation(this, resultsSpan);
+                    SelfPostOperation(this, lastTop, resultsSpan);
                     break;
                     break;
                 case PostOperationType.Compare:
                 case PostOperationType.Compare:
-                    ComparePostOperation(this, resultsSpan);
+                    ComparePostOperation(this, lastTop, resultsSpan);
                     break;
                     break;
             }
             }
 
 
@@ -263,13 +264,15 @@ public static partial class LuaVirtualMachine
                     toCatchFlag = true;
                     toCatchFlag = true;
                     await Task;
                     await Task;
                     Task = default;
                     Task = default;
-                    CurrentReturnFrameBase = State.GetCurrentFrame().ReturnBase;
+                    ref readonly var frame = ref State.GetCurrentFrame();
+                    CurrentReturnFrameBase = frame.ReturnBase;
+                    var variableArgumentCount = frame.VariableArgumentCount;
                     if (PostOperation is not (PostOperationType.TailCall or PostOperationType.DontPop))
                     if (PostOperation is not (PostOperationType.TailCall or PostOperationType.DontPop))
                     {
                     {
                         State.PopCallStackFrame();
                         State.PopCallStackFrame();
                     }
                     }
 
 
-                    if (!ExecutePostOperation(PostOperation))
+                    if (!ExecutePostOperation(variableArgumentCount, PostOperation))
                     {
                     {
                         break;
                         break;
                     }
                     }
@@ -351,7 +354,7 @@ public static partial class LuaVirtualMachine
     {
     {
         try
         try
         {
         {
-        // This is a label to restart the execution when new function is called or restarted
+            // This is a label to restart the execution when new function is called or restarted
         Restart:
         Restart:
             ref var instructionsHead = ref Unsafe.AsRef(in context.Prototype.Code[0]);
             ref var instructionsHead = ref Unsafe.AsRef(in context.Prototype.Code[0]);
             var frameBase = context.FrameBase;
             var frameBase = context.FrameBase;
@@ -412,6 +415,7 @@ public static partial class LuaVirtualMachine
                         {
                         {
                             context.Pc++;
                             context.Pc++;
                         }
                         }
+
                         continue;
                         continue;
                     case OpCode.LoadNil:
                     case OpCode.LoadNil:
                         Markers.LoadNil();
                         Markers.LoadNil();
@@ -958,7 +962,7 @@ public static partial class LuaVirtualMachine
         return mod;
         return mod;
     }
     }
 
 
-    static void SelfPostOperation(VirtualMachineExecutionContext context, Span<LuaValue> results)
+    static void SelfPostOperation(VirtualMachineExecutionContext context, int lastTop, Span<LuaValue> results)
     {
     {
         var stack = context.Stack;
         var stack = context.Stack;
         var instruction = context.Instruction;
         var instruction = context.Instruction;
@@ -968,6 +972,7 @@ public static partial class LuaVirtualMachine
         var table = Unsafe.Add(ref stackHead, RB);
         var table = Unsafe.Add(ref stackHead, RB);
         Unsafe.Add(ref stackHead, RA + 1) = table;
         Unsafe.Add(ref stackHead, RA + 1) = table;
         Unsafe.Add(ref stackHead, RA) = results.Length == 0 ? LuaValue.Nil : results[0];
         Unsafe.Add(ref stackHead, RA) = results.Length == 0 ? LuaValue.Nil : results[0];
+        stack.PopUntil(Math.Max(RA + 2, lastTop));
         stack.NotifyTop(RA + 2);
         stack.NotifyTop(RA + 2);
     }
     }
 
 
@@ -1558,7 +1563,7 @@ public static partial class LuaVirtualMachine
         stack.PopUntil(RA + 1);
         stack.PopUntil(RA + 1);
     }
     }
 
 
-    static void ComparePostOperation(VirtualMachineExecutionContext context, Span<LuaValue> results)
+    static void ComparePostOperation(VirtualMachineExecutionContext context, int lastTop, Span<LuaValue> results)
     {
     {
         var compareResult = results.Length != 0 && results[0].ToBoolean();
         var compareResult = results.Length != 0 && results[0].ToBoolean();
         if (compareResult != (context.Instruction.A == 1))
         if (compareResult != (context.Instruction.A == 1))
@@ -1567,6 +1572,7 @@ public static partial class LuaVirtualMachine
         }
         }
 
 
         results.Clear();
         results.Clear();
+        context.Stack.PopUntil(lastTop);
     }
     }
 
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -1979,7 +1985,7 @@ public static partial class LuaVirtualMachine
             stack.Push(vc);
             stack.Push(vc);
             var (argCount, variableArgumentCount) = PrepareForFunctionCall(context.State, func, newBase);
             var (argCount, variableArgumentCount) = PrepareForFunctionCall(context.State, func, newBase);
             newBase += variableArgumentCount;
             newBase += variableArgumentCount;
-            var newFrame = func.CreateNewFrame(context, newBase, context.Instruction.A + context.FrameBase, variableArgumentCount);
+            var newFrame = func.CreateNewFrame(context, newBase, newBase, variableArgumentCount);
 
 
             context.State.PushCallStackFrame(newFrame);
             context.State.PushCallStackFrame(newFrame);
             if (context.State.CallOrReturnHookMask.Value != 0 && !context.State.IsInHook)
             if (context.State.CallOrReturnHookMask.Value != 0 && !context.State.IsInHook)
@@ -2007,12 +2013,16 @@ public static partial class LuaVirtualMachine
                 return false;
                 return false;
             }
             }
 
 
-            if (task.GetAwaiter().GetResult() == 0)
+            var destIndex = context.Instruction.A + context.FrameBase;
+            if (newFrame.ReturnBase != destIndex)
             {
             {
-                stack.Get(newFrame.ReturnBase) = default;
+                stack.Get(destIndex)
+                    = task.GetAwaiter().GetResult() != 0
+                        ? stack.FastGet(newFrame.ReturnBase)
+                        : default;
             }
             }
 
 
-            stack.PopUntil(newFrame.ReturnBase + 1);
+            stack.PopUntil(Math.Max(destIndex + 1, newFrame.Base - newFrame.VariableArgumentCount));
             context.State.PopCallStackFrame();
             context.State.PopCallStackFrame();
             return true;
             return true;
         }
         }
@@ -2107,7 +2117,7 @@ public static partial class LuaVirtualMachine
             var (argCount, variableArgumentCount) = PrepareForFunctionCall(context.State, func, newBase);
             var (argCount, variableArgumentCount) = PrepareForFunctionCall(context.State, func, newBase);
             newBase += variableArgumentCount;
             newBase += variableArgumentCount;
 
 
-            var newFrame = func.CreateNewFrame(context, newBase, context.Instruction.A + context.FrameBase, variableArgumentCount);
+            var newFrame = func.CreateNewFrame(context, newBase, newBase, variableArgumentCount);
 
 
             context.State.PushCallStackFrame(newFrame);
             context.State.PushCallStackFrame(newFrame);
             if (context.State.CallOrReturnHookMask.Value != 0 && !context.State.IsInHook)
             if (context.State.CallOrReturnHookMask.Value != 0 && !context.State.IsInHook)
@@ -2135,12 +2145,16 @@ public static partial class LuaVirtualMachine
                 return false;
                 return false;
             }
             }
 
 
-            if (task.GetAwaiter().GetResult() == 0)
+            var destIndex = context.Instruction.A + context.FrameBase;
+            if (newFrame.ReturnBase != destIndex)
             {
             {
-                stack.Get(newFrame.ReturnBase) = default;
+                stack.Get(destIndex)
+                    = task.GetAwaiter().GetResult() != 0
+                        ? stack.FastGet(newFrame.ReturnBase)
+                        : default;
             }
             }
 
 
-            stack.PopUntil(newFrame.ReturnBase + 1);
+            stack.PopUntil(Math.Max(destIndex + 1, newFrame.Base - newFrame.VariableArgumentCount));
             context.State.PopCallStackFrame();
             context.State.PopCallStackFrame();
             return true;
             return true;
         }
         }
@@ -2279,14 +2293,14 @@ public static partial class LuaVirtualMachine
             }
             }
 
 
             var results = stack.AsSpan()[newFrame.ReturnBase..];
             var results = stack.AsSpan()[newFrame.ReturnBase..];
-            var compareResult = results.Length == 0 && results[0].ToBoolean();
+            var compareResult = results.Length > 0 && results[0].ToBoolean();
             compareResult = reverseLe ? !compareResult : compareResult;
             compareResult = reverseLe ? !compareResult : compareResult;
             if (compareResult != (context.Instruction.A == 1))
             if (compareResult != (context.Instruction.A == 1))
             {
             {
                 context.Pc++;
                 context.Pc++;
             }
             }
 
 
-            stack.PopUntil(newFrame.ReturnBase + 1);
+            stack.PopUntil(newFrame.Base - newFrame.VariableArgumentCount + 1);
             context.State.PopCallStackFrame();
             context.State.PopCallStackFrame();
 
 
             return true;
             return true;

+ 15 - 4
tests/Lua.Tests/LuaObjectTests.cs

@@ -47,6 +47,12 @@ public partial class LuaTestObj
         await Task.Delay(1);
         await Task.Delay(1);
         return x + y;
         return x + y;
     }
     }
+    
+    [LuaMetamethod(LuaObjectMetamethod.Unm)]
+    public LuaTestObj Unm()
+    {
+        return new LuaTestObj() { x = -x, y = -y };
+    }
 
 
     [LuaMember]
     [LuaMember]
     public object GetObj() => this;
     public object GetObj() => this;
@@ -246,12 +252,17 @@ public class LuaObjectTests
         state.Environment["TestObj"] = userData;
         state.Environment["TestObj"] = userData;
         var results = await state.DoStringAsync("""
         var results = await state.DoStringAsync("""
                                                 function testLen(obj)
                                                 function testLen(obj)
-                                                    local ret=  #obj
-                                                    return ret
+                                                    return #obj
                                                 end
                                                 end
-                                                return testLen(TestObj.create(1, 2))
+                                                local obj = TestObj.create(1, 2)
+                                                return testLen(TestObj.create(1, 2)),-obj
                                                 """);
                                                 """);
-        Assert.That(results, Has.Length.EqualTo(1));
+        Assert.That(results, Has.Length.EqualTo(2));
         Assert.That(results[0].Read<double>(), Is.EqualTo(3));
         Assert.That(results[0].Read<double>(), Is.EqualTo(3));
+        Assert.That(results[1].Read<object>(), Is.TypeOf<LuaTestObj>());
+        var objUnm = results[1].Read<LuaTestObj>();
+        Assert.That(objUnm.X, Is.EqualTo(-1));
+        Assert.That(objUnm.Y, Is.EqualTo(-2));
+        
     }
     }
 }
 }

+ 183 - 0
tests/Lua.Tests/MetaTests.cs

@@ -0,0 +1,183 @@
+using Lua.Platforms;
+using Lua.Standard;
+using Lua.Tests.Helpers;
+using System.Globalization;
+
+namespace Lua.Tests;
+
+[LuaObject]
+partial class MetaFloat
+{
+    [LuaMember("value")] public double Value { get; set; }
+
+    [LuaMetamethod(LuaObjectMetamethod.Call)]
+    public static MetaFloat Create(LuaValue dummy, double value)
+    {
+        return new MetaFloat() { Value = value };
+    }
+
+    [LuaMetamethod(LuaObjectMetamethod.Add)]
+    public static MetaFloat Add(MetaFloat a, MetaFloat b)
+    {
+        return new MetaFloat() { Value = a.Value + b.Value };
+    }
+
+    [LuaMetamethod(LuaObjectMetamethod.Lt)]
+    public static bool Lt(MetaFloat a, MetaFloat b)
+    {
+        return a.Value < b.Value;
+    }
+
+    [LuaMetamethod(LuaObjectMetamethod.Len)]
+    public static double Len(MetaFloat a)
+    {
+        return a.Value;
+    }
+
+    [LuaMetamethod(LuaObjectMetamethod.ToString)]
+    public override string ToString()
+    {
+        return Value.ToString(CultureInfo.InvariantCulture);
+    }
+}
+
+[LuaObject]
+partial class MetaAsyncFloat
+{
+    [LuaMember("value")] public double Value { get; set; }
+
+    [LuaMetamethod(LuaObjectMetamethod.Call)]
+    public static async Task<MetaFloat> Create(LuaValue dummy, double value)
+    {
+        await Task.Delay(1);
+        return new MetaFloat() { Value = value };
+    }
+
+    [LuaMetamethod(LuaObjectMetamethod.Add)]
+    public static async Task<MetaFloat> Add(MetaFloat a, MetaFloat b)
+    {
+        await Task.Delay(1);
+        return new MetaFloat() { Value = a.Value + b.Value };
+    }
+
+    [LuaMetamethod(LuaObjectMetamethod.Lt)]
+    public static async Task<bool> Lt(MetaFloat a, MetaFloat b)
+    {
+        await Task.Delay(1);
+        return a.Value < b.Value;
+    }
+
+    [LuaMetamethod(LuaObjectMetamethod.Len)]
+    public static async Task<double> Len(MetaFloat a)
+    {
+        await Task.Delay(1);
+        return a.Value;
+    }
+
+    [LuaMetamethod(LuaObjectMetamethod.ToString)]
+    public override string ToString()
+    {
+        return Value.ToString(CultureInfo.InvariantCulture);
+    }
+}
+
+public class MetaTests
+{
+    const string TestFloatScript = """
+                                   local a = MetaFloat(10)
+                                   local b = MetaFloat(20)
+                                   local function test(x,y)
+                                     local z = x + y
+                                     local v = z.value
+                                     local len, str = #z, tostring(z)
+                                     return v, y > x ,len, str
+                                   end
+                                   local v, comp, len, str = test(a,b)
+                                   assert(v == 30)
+                                   assert(comp == true)
+                                   assert(len == 30)
+                                   assert(str == "30")
+                                   local c = a + b
+                                   return c.value, a < b,#c, tostring(c)
+                                   """;
+
+    const string TestIndexScript = """
+                                   local obj = setmetatable({}, {__index = getindentityMethod})
+                                   local a
+                                   local b = 1
+                                   local c =obj
+                                   a = obj:getindentity()
+                                   local d = obj.getindentity(obj)
+                                   print(d)
+                                   assert(a == obj)
+                                   assert(b == 1)
+                                   assert(obj == c)
+                                   assert(d == c)
+                                   """;
+
+    [Test]
+    public async Task TestMetaFloat()
+    {
+        var lua = LuaState.Create();
+        lua.OpenBasicLibrary();
+        lua.Environment["MetaFloat"] = new MetaFloat();
+        var result = await lua.DoStringAsync(TestFloatScript);
+        Assert.That(result.Length, Is.EqualTo(4));
+        Assert.That(result[0].Read<double>(), Is.EqualTo(30));
+        Assert.That(result[1].Read<bool>(), Is.EqualTo(true));
+        Assert.That(result[2].Read<double>(), Is.EqualTo(30));
+        Assert.That(result[3].Read<string>(), Is.EqualTo("30"));
+    }
+
+    [Test]
+    public async Task TestMetaAsyncFloat()
+    {
+        var lua = LuaState.Create();
+        lua.OpenBasicLibrary();
+        lua.Environment["MetaFloat"] = new MetaAsyncFloat();
+
+        var result = await lua.DoStringAsync(TestFloatScript);
+        Assert.That(result.Length, Is.EqualTo(4));
+        Assert.That(result[0].Read<double>(), Is.EqualTo(30));
+        Assert.That(result[1].Read<bool>(), Is.EqualTo(true));
+        Assert.That(result[2].Read<double>(), Is.EqualTo(30));
+        Assert.That(result[3].Read<string>(), Is.EqualTo("30"));
+    }
+
+    [Test]
+    public async Task TestMetaIndex()
+    {
+        var lua = LuaState.Create(LuaPlatform.Default with { StandardIO = new TestStandardIO() });
+        lua.OpenBasicLibrary();
+
+        lua.Environment["getindentityMethod"] = new LuaFunction("getindentity", (context, ct) =>
+        {
+            var obj = context.GetArgument(0);
+            return new(context.Return(new LuaFunction("getIndexed",
+                (ctx, ct2) =>
+                    new(ctx.Return(obj)))));
+        });
+
+
+        var result = await lua.DoStringAsync(TestIndexScript);
+    }
+
+    [Test]
+    public async Task TestMetaIndexAsync()
+    {
+        var lua = LuaState.Create(LuaPlatform.Default with { StandardIO = new TestStandardIO() });
+        lua.OpenBasicLibrary();
+
+
+        lua.Environment["getindentityMethod"] = new LuaFunction("getindentity", async (context, ct) =>
+        {
+            var obj = context.GetArgument(0);
+            await Task.Delay(1);
+            return (context.Return(new LuaFunction("getIndexed", (ctx, ct2)
+                => new(ctx.Return(obj)))));
+        });
+
+
+        var result = await lua.DoStringAsync(TestIndexScript);
+    }
+}