Browse Source

Merge pull request #99 from Akeit0/fix-metamethod-__call

Fix: Self is not passed to `__call` metamethod
Akeit0 8 months ago
parent
commit
5b706bc13b

+ 82 - 36
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -96,15 +96,15 @@ public static partial class LuaVirtualMachine
             switch (opCode)
             {
                 case OpCode.Call:
+                {
+                    var c = callInstruction.C;
+                    if (c != 0)
                     {
-                        var c = callInstruction.C;
-                        if (c != 0)
-                        {
-                            targetCount = c - 1;
-                        }
-
-                        break;
+                        targetCount = c - 1;
                     }
+
+                    break;
+                }
                 case OpCode.TForCall:
                     target += 3;
                     targetCount = callInstruction.C;
@@ -1034,11 +1034,15 @@ public static partial class LuaVirtualMachine
         var instruction = context.Instruction;
         var RA = instruction.A + context.FrameBase;
         var va = context.Stack.Get(RA);
+        var newBase = RA + 1;
+        bool isMetamethod = false;
         if (!va.TryReadFunction(out var func))
         {
             if (va.TryGetMetamethod(context.State, Metamethods.Call, out var metamethod) &&
                 metamethod.TryReadFunction(out func))
             {
+                newBase -= 1;
+                isMetamethod = true;
             }
             else
             {
@@ -1047,8 +1051,8 @@ public static partial class LuaVirtualMachine
         }
 
         var thread = context.Thread;
-        var (newBase, argumentCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, instruction, RA);
-
+        var (argumentCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, instruction, newBase, isMetamethod);
+        newBase += variableArgumentCount;
         var newFrame = func.CreateNewFrame(ref context, newBase, variableArgumentCount);
 
         thread.PushCallStackFrame(newFrame);
@@ -1160,6 +1164,8 @@ public static partial class LuaVirtualMachine
         var instruction = context.Instruction;
         var stack = context.Stack;
         var RA = instruction.A + context.FrameBase;
+        var newBase = RA + 1;
+        bool isMetamethod = false;
         var state = context.State;
         var thread = context.Thread;
 
@@ -1168,19 +1174,23 @@ public static partial class LuaVirtualMachine
         var va = stack.Get(RA);
         if (!va.TryReadFunction(out var func))
         {
-            if (!va.TryGetMetamethod(state, Metamethods.Call, out var metamethod) &&
-                !metamethod.TryReadFunction(out func))
+            if (va.TryGetMetamethod(state, Metamethods.Call, out var metamethod) &&
+                metamethod.TryReadFunction(out func))
+            {
+                isMetamethod = true;
+                newBase -= 1;
+            }
+            else
             {
                 LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
             }
         }
 
-        var (newBase, argumentCount, variableArgumentCount) = PrepareForFunctionTailCall(thread, func, instruction, RA);
-
+        var (argumentCount, variableArgumentCount) = PrepareForFunctionTailCall(thread, func, instruction, newBase, isMetamethod);
+        newBase = context.FrameBase + variableArgumentCount;
 
         var lastPc = thread.CallStack.AsSpan()[^1].CallerInstructionIndex;
         context.Thread.PopCallStackFrameUnsafe();
-
         var newFrame = func.CreateNewFrame(ref context, newBase, variableArgumentCount);
 
         newFrame.Flags |= CallStackFrameFlags.TailCall;
@@ -1233,18 +1243,46 @@ public static partial class LuaVirtualMachine
         var instruction = context.Instruction;
         var stack = context.Stack;
         var RA = instruction.A + context.FrameBase;
-
+        bool isMetamethod = false;
         var iteratorRaw = stack.Get(RA);
         if (!iteratorRaw.TryReadFunction(out var iterator))
         {
-            LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", iteratorRaw);
+            if (iteratorRaw.TryGetMetamethod(context.State, Metamethods.Call, out var metamethod) &&
+                metamethod.TryReadFunction(out iterator))
+            {
+                isMetamethod = true;
+            }
+            else
+            {
+                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
+            }
         }
 
         var newBase = RA + 3 + instruction.C;
-        stack.Get(newBase) = stack.Get(RA + 1);
-        stack.Get(newBase + 1) = stack.Get(RA + 2);
-        stack.NotifyTop(newBase + 2);
-        var newFrame = iterator.CreateNewFrame(ref context, newBase);
+
+        if (isMetamethod)
+        {
+            stack.Get(newBase) = iteratorRaw;
+            stack.Get(newBase + 1) = stack.Get(RA + 1);
+            stack.Get(newBase + 2) = stack.Get(RA + 2);
+            stack.NotifyTop(newBase + 3);
+        }
+        else
+        {
+            stack.Get(newBase) = stack.Get(RA + 1);
+            stack.Get(newBase + 1) = stack.Get(RA + 2);
+            stack.NotifyTop(newBase + 2);
+        }
+
+        var argumentCount = isMetamethod ? 3 : 2;
+        var variableArgumentCount = iterator.GetVariableArgumentCount(argumentCount);
+        if (variableArgumentCount != 0)
+        {
+            PrepareVariableArgument(stack, newBase, argumentCount, variableArgumentCount);
+            newBase += variableArgumentCount;
+        }
+
+        var newFrame = iterator.CreateNewFrame(ref context, newBase, variableArgumentCount);
         context.Thread.PushCallStackFrame(newFrame);
         if (iterator is LuaClosure)
         {
@@ -1370,7 +1408,7 @@ public static partial class LuaVirtualMachine
             }
 
             table = metatableValue;
-        Function:
+            Function:
             if (table.TryReadFunction(out var function))
             {
                 return CallGetTableFunc(targetTable, function, key, ref context, out value, out doRestart);
@@ -1468,7 +1506,7 @@ public static partial class LuaVirtualMachine
 
             table = metatableValue;
 
-        Function:
+            Function:
             if (table.TryReadFunction(out var function))
             {
                 context.PostOperation = PostOperationType.Nop;
@@ -1732,7 +1770,7 @@ public static partial class LuaVirtualMachine
     // If there are variable arguments, the base of the stack is moved by that number and the values of the variable arguments are placed in front of it.
     // see: https://wubingzheng.github.io/build-lua-in-rust/en/ch08-02.arguments.html
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static (int FrameBase, int ArgumentCount, int VariableArgumentCount) PrepareVariableArgument(LuaStack stack, int newBase, int argumentCount,
+    static ( int ArgumentCount, int VariableArgumentCount) PrepareVariableArgument(LuaStack stack, int newBase, int argumentCount,
         int variableArgumentCount)
     {
         var temp = newBase;
@@ -1744,51 +1782,59 @@ public static partial class LuaVirtualMachine
         var stackBuffer = stack.GetBuffer()[temp..];
         stackBuffer[..argumentCount].CopyTo(stackBuffer[variableArgumentCount..]);
         stackBuffer.Slice(argumentCount, variableArgumentCount).CopyTo(stackBuffer);
-        return (newBase, argumentCount, variableArgumentCount);
+        return (argumentCount, variableArgumentCount);
     }
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    static (int FrameBase, int ArgumentCount, int VariableArgumentCount) PrepareForFunctionCall(LuaThread thread, LuaFunction function,
-        Instruction instruction, int RA)
+    static (int ArgumentCount, int VariableArgumentCount) PrepareForFunctionCall(LuaThread thread, LuaFunction function,
+        Instruction instruction, int newBase, bool isMetaMethod)
     {
         var argumentCount = instruction.B - 1;
         if (argumentCount == -1)
         {
-            argumentCount = (ushort)(thread.Stack.Count - (RA + 1));
+            argumentCount = (ushort)(thread.Stack.Count - newBase);
         }
         else
         {
-            thread.Stack.NotifyTop(RA + 1 + argumentCount);
+            if (isMetaMethod)
+            {
+                argumentCount += 1;
+            }
+
+            thread.Stack.NotifyTop(newBase + argumentCount);
         }
 
-        var newBase = RA + 1;
         var variableArgumentCount = function.GetVariableArgumentCount(argumentCount);
 
         if (variableArgumentCount <= 0)
         {
-            return (newBase, argumentCount, 0);
+            return (argumentCount, 0);
         }
 
         return PrepareVariableArgument(thread.Stack, newBase, argumentCount, variableArgumentCount);
     }
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    static (int FrameBase, int ArgumentCount, int VariableArgumentCount) PrepareForFunctionTailCall(LuaThread thread, LuaFunction function,
-        Instruction instruction, int RA)
+    static (int ArgumentCount, int VariableArgumentCount) PrepareForFunctionTailCall(LuaThread thread, LuaFunction function,
+        Instruction instruction, int newBase, bool isMetaMethod)
     {
         var stack = thread.Stack;
 
         var argumentCount = instruction.B - 1;
         if (instruction.B == 0)
         {
-            argumentCount = (ushort)(stack.Count - (RA + 1));
+            argumentCount = (ushort)(stack.Count - newBase);
         }
         else
         {
-            thread.Stack.NotifyTop(RA + 1 + argumentCount);
+            if (isMetaMethod)
+            {
+                argumentCount += 1;
+            }
+
+            thread.Stack.NotifyTop(newBase + argumentCount);
         }
 
-        var newBase = RA + 1;
 
         // In the case of tailcall, the local variables of the caller are immediately discarded, so there is no need to retain them.
         // Therefore, a call can be made without allocating new registers.
@@ -1804,7 +1850,7 @@ public static partial class LuaVirtualMachine
 
         if (variableArgumentCount <= 0)
         {
-            return (newBase, argumentCount, 0);
+            return (argumentCount, 0);
         }
 
         return PrepareVariableArgument(thread.Stack, newBase, argumentCount, variableArgumentCount);

+ 25 - 0
tests/Lua.Tests/LuaObjectTests.cs

@@ -1,3 +1,5 @@
+using Lua.Standard;
+
 namespace Lua.Tests;
 
 [LuaObject]
@@ -32,6 +34,12 @@ public partial class TestUserData
     {
         return Property;
     }
+
+    [LuaMetamethod(LuaObjectMetamethod.Call)]
+    public string Call()
+    {
+        return "Called!";
+    }
 }
 
 public class LuaObjectTests
@@ -120,4 +128,21 @@ public class LuaObjectTests
         Assert.That(results, Has.Length.EqualTo(1));
         Assert.That(results[0], Is.EqualTo(new LuaValue(1)));
     }
+
+    [Test]
+    public async Task Test_CallMetamethod()
+    {
+        var userData = new TestUserData();
+
+        var state = LuaState.Create();
+        state.OpenBasicLibrary();
+        state.Environment["test"] = userData;
+        var results = await state.DoStringAsync("""
+                                                assert(test() == 'Called!')
+                                                return test()
+                                                """);
+
+        Assert.That(results, Has.Length.EqualTo(1));
+        Assert.That(results[0], Is.EqualTo(new LuaValue("Called!")));
+    }
 }

+ 50 - 0
tests/Lua.Tests/MetatableTests.cs

@@ -87,6 +87,56 @@ assert(metatable.__newindex.x == 2)
         await state.DoStringAsync(source);
     }
 
+    [Test]
+    public async Task Test_Metamethod_Call()
+    {
+        var source = @"
+metatable = {
+    __call = function(a, b)
+        return a.x + b
+    end
+}
+
+local a = {}
+a.x = 1
+setmetatable(a, metatable)
+assert(a(2) == 3)
+function tail(a, b)
+    return a(b)
+end
+tail(a, 3)
+assert(tail(a, 3) == 4)
+";
+        await state.DoStringAsync(source);
+    }
+    
+    [Test]
+    public async Task Test_Metamethod_TForCall()
+    {
+        var source = @"
+local i =3
+function a(...)
+  local v ={...}
+   assert(v[1] ==t)
+   assert(v[2] == nil)
+   if i ==3 then
+       assert(v[3] == nil)
+    else
+      assert(v[3] == i)
+    end
+   
+   i  =i -1
+   if i ==0 then return nil end
+   return i
+end
+
+t =setmetatable({},{__call = a})
+
+for i in t do 
+end
+";
+        await state.DoStringAsync(source);
+    }
     [Test]
     public async Task Test_Hook_Metamethods()
     {