Browse Source

Fix: Self is not passed to `__call` metamethod

Akeit0 10 months ago
parent
commit
6bcfa1be1f

+ 47 - 28
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -92,15 +92,15 @@ public static partial class LuaVirtualMachine
             switch (opCode)
             {
                 case OpCode.Call:
+                {
+                    var c = callInstruction.C;
+                    if (c != 0)
                     {
-                        var c = callInstruction.C;
-                        if (c != 0)
-                        {
-                            targetCount = c - 1;
-                        }
-
-                        break;
+                        targetCount = c - 1;
                     }
+
+                    break;
+                }
                 case OpCode.TForCall:
                     target += 3;
                     targetCount = callInstruction.C;
@@ -265,7 +265,7 @@ public static partial class LuaVirtualMachine
 
         try
         {
-        // This is a label to restart the execution when new function is called or restarted
+            // This is a label to restart the execution when new function is called or restarted
         Restart:
             ref var instructionsHead = ref context.Chunk.Instructions[0];
             var frameBase = context.FrameBase;
@@ -1014,11 +1014,15 @@ public static partial class LuaVirtualMachine
         var instruction = context.Instruction;
         var RA = instruction.A + context.FrameBase;
         var va = context.Stack.Get(RA);
+        var newBase = RA + 1;
+        bool isMetamethod = false;
         if (!va.TryReadFunction(out var func))
         {
             if (va.TryGetMetamethod(context.State, Metamethods.Call, out var metamethod) &&
                 metamethod.TryReadFunction(out func))
             {
+                newBase -= 1;
+                isMetamethod = true;
             }
             else
             {
@@ -1027,8 +1031,8 @@ public static partial class LuaVirtualMachine
         }
 
         var thread = context.Thread;
-        var (newBase, argumentCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, instruction, RA);
-
+        var (argumentCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, instruction, newBase, isMetamethod);
+        newBase += variableArgumentCount;
         var newFrame = func.CreateNewFrame(ref context, newBase, variableArgumentCount);
 
         thread.PushCallStackFrame(newFrame);
@@ -1131,6 +1135,8 @@ public static partial class LuaVirtualMachine
         var instruction = context.Instruction;
         var stack = context.Stack;
         var RA = instruction.A + context.FrameBase;
+        var newBase = RA + 1;
+        bool isMetamethod = false;
         var state = context.State;
         var thread = context.Thread;
 
@@ -1139,15 +1145,20 @@ public static partial class LuaVirtualMachine
         var va = stack.Get(RA);
         if (!va.TryReadFunction(out var func))
         {
-            if (!va.TryGetMetamethod(state, Metamethods.Call, out var metamethod) &&
-                !metamethod.TryReadFunction(out func))
+            if (va.TryGetMetamethod(state, Metamethods.Call, out var metamethod) &&
+                metamethod.TryReadFunction(out func))
+            {
+                isMetamethod = true;
+                newBase -= 1;
+            }
+            else
             {
                 LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
             }
         }
 
-        var (newBase, argumentCount, variableArgumentCount) = PrepareForFunctionTailCall(thread, func, instruction, RA);
-
+        var (argumentCount, variableArgumentCount) = PrepareForFunctionTailCall(thread, func, instruction, newBase, isMetamethod);
+        newBase = context.FrameBase + variableArgumentCount;
         var newFrame = func.CreateNewFrame(ref context, newBase, variableArgumentCount);
         thread.PushCallStackFrame(newFrame);
 
@@ -1646,7 +1657,7 @@ public static partial class LuaVirtualMachine
     // If there are variable arguments, the base of the stack is moved by that number and the values of the variable arguments are placed in front of it.
     // see: https://wubingzheng.github.io/build-lua-in-rust/en/ch08-02.arguments.html
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static (int FrameBase, int ArgumentCount, int VariableArgumentCount) PrepareVariableArgument(LuaStack stack, int newBase, int argumentCount,
+    static ( int ArgumentCount, int VariableArgumentCount) PrepareVariableArgument(LuaStack stack, int newBase, int argumentCount,
         int variableArgumentCount)
     {
         var temp = newBase;
@@ -1658,51 +1669,59 @@ public static partial class LuaVirtualMachine
         var stackBuffer = stack.GetBuffer()[temp..];
         stackBuffer[..argumentCount].CopyTo(stackBuffer[variableArgumentCount..]);
         stackBuffer.Slice(argumentCount, variableArgumentCount).CopyTo(stackBuffer);
-        return (newBase, argumentCount, variableArgumentCount);
+        return (argumentCount, variableArgumentCount);
     }
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    static (int FrameBase, int ArgumentCount, int VariableArgumentCount) PrepareForFunctionCall(LuaThread thread, LuaFunction function,
-        Instruction instruction, int RA)
+    static (int ArgumentCount, int VariableArgumentCount) PrepareForFunctionCall(LuaThread thread, LuaFunction function,
+        Instruction instruction, int newBase, bool isMetaMethod)
     {
         var argumentCount = instruction.B - 1;
         if (argumentCount == -1)
         {
-            argumentCount = (ushort)(thread.Stack.Count - (RA + 1));
+            argumentCount = (ushort)(thread.Stack.Count - newBase);
         }
         else
         {
-            thread.Stack.NotifyTop(RA + 1 + argumentCount);
+            if (isMetaMethod)
+            {
+                argumentCount += 1;
+            }
+
+            thread.Stack.NotifyTop(newBase + argumentCount);
         }
 
-        var newBase = RA + 1;
         var variableArgumentCount = function.GetVariableArgumentCount(argumentCount);
 
         if (variableArgumentCount <= 0)
         {
-            return (newBase, argumentCount, 0);
+            return (argumentCount, 0);
         }
 
         return PrepareVariableArgument(thread.Stack, newBase, argumentCount, variableArgumentCount);
     }
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    static (int FrameBase, int ArgumentCount, int VariableArgumentCount) PrepareForFunctionTailCall(LuaThread thread, LuaFunction function,
-        Instruction instruction, int RA)
+    static (int ArgumentCount, int VariableArgumentCount) PrepareForFunctionTailCall(LuaThread thread, LuaFunction function,
+        Instruction instruction, int newBase, bool isMetaMethod)
     {
         var stack = thread.Stack;
 
         var argumentCount = instruction.B - 1;
         if (instruction.B == 0)
         {
-            argumentCount = (ushort)(stack.Count - (RA + 1));
+            argumentCount = (ushort)(stack.Count - newBase);
         }
         else
         {
-            thread.Stack.NotifyTop(RA + 1 + argumentCount);
+            if (isMetaMethod)
+            {
+                argumentCount += 1;
+            }
+
+            thread.Stack.NotifyTop(newBase + argumentCount);
         }
 
-        var newBase = RA + 1;
 
         // In the case of tailcall, the local variables of the caller are immediately discarded, so there is no need to retain them.
         // Therefore, a call can be made without allocating new registers.
@@ -1718,7 +1737,7 @@ public static partial class LuaVirtualMachine
 
         if (variableArgumentCount <= 0)
         {
-            return (newBase, argumentCount, 0);
+            return (argumentCount, 0);
         }
 
         return PrepareVariableArgument(thread.Stack, newBase, argumentCount, variableArgumentCount);

+ 29 - 2
tests/Lua.Tests/LuaObjectTests.cs

@@ -1,3 +1,5 @@
+using Lua.Standard;
+
 namespace Lua.Tests;
 
 [LuaObject]
@@ -24,6 +26,8 @@ public partial class TestUserData
     [LuaMember]
     public static double StaticMethodWithReturnValue(double a, double b)
     {
+        Console.WriteLine(a);
+        Console.WriteLine(b);
         return a + b;
     }
 
@@ -32,6 +36,12 @@ public partial class TestUserData
     {
         return Property;
     }
+
+    [LuaMetamethod(LuaObjectMetamethod.Call)]
+    public string Call()
+    {
+        return "Called!";
+    }
 }
 
 public class LuaObjectTests
@@ -99,10 +109,10 @@ public class LuaObjectTests
 
         var state = LuaState.Create();
         state.Environment["test"] = userData;
-        var results = await state.DoStringAsync("return test.StaticMethodWithReturnValue(1, 2)");
+        var results = await state.DoStringAsync("return test.StaticMethodWithReturnValue(1, -2)");
 
         Assert.That(results, Has.Length.EqualTo(1));
-        Assert.That(results[0], Is.EqualTo(new LuaValue(3)));
+        Assert.That(results[0], Is.EqualTo(new LuaValue(-1)));
     }
 
     [Test]
@@ -120,4 +130,21 @@ public class LuaObjectTests
         Assert.That(results, Has.Length.EqualTo(1));
         Assert.That(results[0], Is.EqualTo(new LuaValue(1)));
     }
+
+    [Test]
+    public async Task Test_CallMetamethod()
+    {
+        var userData = new TestUserData();
+
+        var state = LuaState.Create();
+        state.OpenBasicLibrary();
+        state.Environment["test"] = userData;
+        var results = await state.DoStringAsync("""
+                                                assert(test() == 'Called!')
+                                                return test()
+                                                """);
+
+        Assert.That(results, Has.Length.EqualTo(1));
+        Assert.That(results[0], Is.EqualTo(new LuaValue("Called!")));
+    }
 }

+ 23 - 0
tests/Lua.Tests/MetatableTests.cs

@@ -83,6 +83,29 @@ a.x = nil
 a.x = 2
 assert(a.x == nil)
 assert(metatable.__newindex.x == 2)
+";
+        await state.DoStringAsync(source);
+    }
+
+    [Test]
+    public async Task Test_Metamethod_Call()
+    {
+        var source = @"
+metatable = {
+    __call = function(a, b)
+        return a.x + b
+    end
+}
+
+local a = {}
+a.x = 1
+setmetatable(a, metatable)
+assert(a(2) == 3)
+function tail(a, b)
+    return a(b)
+end
+tail(a, 3)
+assert(tail(a, 3) == 4)
 ";
         await state.DoStringAsync(source);
     }