Browse Source

fix: use __call via other metamethods

Akeit0 7 months ago
parent
commit
5882dae637
3 changed files with 180 additions and 38 deletions
  1. 119 38
      src/Lua/Runtime/LuaVirtualMachine.cs
  2. 35 0
      tests/Lua.Tests/LuaApiTests.cs
  3. 26 0
      tests/Lua.Tests/MetatableTests.cs

+ 119 - 38
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -1011,26 +1011,37 @@ public static partial class LuaVirtualMachine
         if (vb.TryGetMetamethod(context.State, name, out var metamethod) ||
             vc.TryGetMetamethod(context.State, name, out metamethod))
         {
+            var stack = context.Stack;
+            var argCount = 2;
+            var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetThreadWithCurrentPc(context), "call", metamethod);
+                if (metamethod.TryGetMetamethod(context.State, Metamethods.Call, out metamethod) &&
+                    metamethod.TryReadFunction(out func))
+                {
+                    stack.Push(callable);
+                    argCount++;
+                }
+                else
+                {
+                    LuaRuntimeException.AttemptInvalidOperation(GetThreadWithCurrentPc(context), "call", metamethod);
+                }
             }
 
-            var stack = context.Stack;
             stack.Push(vb);
             stack.Push(vc);
-            var varArgCount = func.GetVariableArgumentCount(2);
+            var varArgCount = func.GetVariableArgumentCount(argCount);
 
-            var newFrame = func.CreateNewFrame(context, stack.Count - 2 + varArgCount, target, varArgCount);
+            var newFrame = func.CreateNewFrame(context, stack.Count - argCount + varArgCount, target, varArgCount);
 
             context.Thread.PushCallStackFrame(newFrame);
             if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
             {
-                await ExecuteCallHook(context, newFrame, 2);
+                await ExecuteCallHook(context, newFrame, argCount);
             }
 
 
-            await func.Invoke(context, newFrame, 2);
+            await func.Invoke(context, newFrame, argCount);
             stack.PopUntil(target + 1);
             context.Thread.PopCallStackFrame();
             context.PostOperation = PostOperationType.DontPop;
@@ -1682,23 +1693,34 @@ public static partial class LuaVirtualMachine
         if (vb.TryGetMetamethod(context.State, name, out var metamethod) ||
             vc.TryGetMetamethod(context.State, name, out metamethod))
         {
+            var stack = context.Stack;
+            var argCount = 2;
+            var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetThreadWithCurrentPc(context), "call", metamethod);
+                if (metamethod.TryGetMetamethod(context.State, Metamethods.Call, out metamethod) &&
+                    metamethod.TryReadFunction(out func))
+                {
+                    stack.Push(callable);
+                    argCount++;
+                }
+                else
+                {
+                    LuaRuntimeException.AttemptInvalidOperation(GetThreadWithCurrentPc(context), "call", metamethod);
+                }
             }
 
-            var stack = context.Stack;
             stack.Push(vb);
             stack.Push(vc);
-            var varArgCount = func.GetVariableArgumentCount(2);
+            var varArgCount = func.GetVariableArgumentCount(argCount);
 
-            var newFrame = func.CreateNewFrame(context, stack.Count - 2 + varArgCount, context.FrameBase + context.Instruction.A, varArgCount);
+            var newFrame = func.CreateNewFrame(context, stack.Count - argCount + varArgCount, context.FrameBase + context.Instruction.A, varArgCount);
 
             context.Thread.PushCallStackFrame(newFrame);
             if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
             {
                 context.PostOperation = PostOperationType.SetResult;
-                context.Task = ExecuteCallHook(context, newFrame, 2);
+                context.Task = ExecuteCallHook(context, newFrame, argCount);
                 doRestart = false;
                 return false;
             }
@@ -1711,7 +1733,7 @@ public static partial class LuaVirtualMachine
             }
 
 
-            var task = func.Invoke(context, newFrame, 2);
+            var task = func.Invoke(context, newFrame, argCount);
 
             if (!task.IsCompleted)
             {
@@ -1743,21 +1765,32 @@ public static partial class LuaVirtualMachine
         if (vb.TryGetMetamethod(thread.State, name, out var metamethod) ||
             vc.TryGetMetamethod(thread.State, name, out metamethod))
         {
+            var stack = thread.Stack;
+            var top = stack.Count;
+            var argCount = 2;
+            var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             {
-                LuaRuntimeException.AttemptInvalidOperation(thread, "call", metamethod);
+                if (metamethod.TryGetMetamethod(thread.State, Metamethods.Call, out metamethod) &&
+                    metamethod.TryReadFunction(out func))
+                {
+                    stack.Push(callable);
+                    argCount++;
+                }
+                else
+                {
+                    LuaRuntimeException.AttemptInvalidOperation(thread, "call", metamethod);
+                }
             }
 
-            var stack = thread.Stack;
-            var top = stack.Count;
             stack.Push(vb);
             stack.Push(vc);
-            var varArgCount = func.GetVariableArgumentCount(2);
+            var varArgCount = func.GetVariableArgumentCount(argCount);
 
-            var newFrame = new CallStackFrame() { Base = thread.Stack.Count - 2 + varArgCount, VariableArgumentCount = varArgCount, Function = func, ReturnBase = top };
+            var newFrame = new CallStackFrame() { Base = thread.Stack.Count - argCount + varArgCount, VariableArgumentCount = varArgCount, Function = func, ReturnBase = top };
 
             thread.PushCallStackFrame(newFrame);
-            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = 2, ReturnFrameBase = top };
+            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = top };
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             {
                 await ExecuteCallHook(functionContext, ct);
@@ -1785,15 +1818,28 @@ public static partial class LuaVirtualMachine
         var stack = context.Stack;
         if (vb.TryGetMetamethod(context.State, name, out var metamethod))
         {
+            var argCount = 2;
+            var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetThreadWithCurrentPc(context), "call", metamethod);
+                if (metamethod.TryGetMetamethod(context.State, Metamethods.Call, out metamethod) &&
+                    metamethod.TryReadFunction(out func))
+                {
+                    stack.Push(callable);
+                    argCount++;
+                }
+                else
+                {
+                    LuaRuntimeException.AttemptInvalidOperation(GetThreadWithCurrentPc(context), "call", metamethod);
+                }
             }
 
+
             stack.Push(vb);
             stack.Push(vb);
-            var varArgCount = func.GetVariableArgumentCount(2);
-            var newFrame = func.CreateNewFrame(context, stack.Count - 2 + varArgCount, context.FrameBase + context.Instruction.A, varArgCount);
+            var varArgCount = func.GetVariableArgumentCount(argCount);
+
+            var newFrame = func.CreateNewFrame(context, stack.Count - argCount + varArgCount, context.FrameBase + context.Instruction.A, varArgCount);
 
             context.Thread.PushCallStackFrame(newFrame);
             if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
@@ -1847,20 +1893,32 @@ public static partial class LuaVirtualMachine
 
         if (vb.TryGetMetamethod(thread.State, name, out var metamethod))
         {
+            var stack = thread.Stack;
+            var top = stack.Count;
+            var argCount = 2;
+            var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             {
-                LuaRuntimeException.AttemptInvalidOperation(thread, "call", metamethod);
+                if (metamethod.TryGetMetamethod(thread.State, Metamethods.Call, out metamethod) &&
+                    metamethod.TryReadFunction(out func))
+                {
+                    stack.Push(callable);
+                    argCount++;
+                }
+                else
+                {
+                    LuaRuntimeException.AttemptInvalidOperation(thread, "call", metamethod);
+                }
             }
 
-            var stack = thread.Stack;
-            var top = stack.Count;
             stack.Push(vb);
-            var varArgCount = func.GetVariableArgumentCount(1);
+            stack.Push(vb);
+            var varArgCount = func.GetVariableArgumentCount(argCount);
 
-            var newFrame = new CallStackFrame() { Base = thread.Stack.Count - 1 + varArgCount, VariableArgumentCount = varArgCount, Function = func, ReturnBase = top };
+            var newFrame = new CallStackFrame() { Base = thread.Stack.Count - argCount + varArgCount, VariableArgumentCount = varArgCount, Function = func, ReturnBase = top };
 
             thread.PushCallStackFrame(newFrame);
-            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = 2, ReturnFrameBase = top };
+            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = top };
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             {
                 await ExecuteCallHook(functionContext, ct);
@@ -1890,21 +1948,33 @@ public static partial class LuaVirtualMachine
         if (vb.TryGetMetamethod(context.State, name, out var metamethod) ||
             vc.TryGetMetamethod(context.State, name, out metamethod))
         {
+            var stack = context.Stack;
+            var argCount = 2;
+            var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetThreadWithCurrentPc(context), "call", metamethod);
+                if (metamethod.TryGetMetamethod(context.State, Metamethods.Call, out metamethod) &&
+                    metamethod.TryReadFunction(out func))
+                {
+                    stack.Push(callable);
+                    argCount++;
+                }
+                else
+                {
+                    LuaRuntimeException.AttemptInvalidOperation(GetThreadWithCurrentPc(context), "call", metamethod);
+                }
             }
 
-            var stack = context.Stack;
             stack.Push(vb);
             stack.Push(vc);
-            var newFrame = func.CreateNewFrame(context, stack.Count - 2);
+            var varArgCount = func.GetVariableArgumentCount(argCount);
+            var newFrame = func.CreateNewFrame(context, stack.Count - argCount + varArgCount);
             if (reverseLe) newFrame.Flags |= CallStackFrameFlags.ReversedLe;
             context.Thread.PushCallStackFrame(newFrame);
             if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
             {
                 context.PostOperation = PostOperationType.Compare;
-                context.Task = ExecuteCallHook(context, newFrame, 2);
+                context.Task = ExecuteCallHook(context, newFrame, argCount);
                 doRestart = false;
                 return false;
             }
@@ -1916,7 +1986,7 @@ public static partial class LuaVirtualMachine
                 return true;
             }
 
-            var task = func.Invoke(context, newFrame, 2);
+            var task = func.Invoke(context, newFrame, argCount);
 
             if (!task.IsCompleted)
             {
@@ -1976,21 +2046,32 @@ public static partial class LuaVirtualMachine
         if (vb.TryGetMetamethod(thread.State, name, out var metamethod) ||
             vc.TryGetMetamethod(thread.State, name, out metamethod))
         {
+            var stack = thread.Stack;
+            var top = stack.Count;
+            var argCount = 2;
+            var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             {
-                LuaRuntimeException.AttemptInvalidOperation(thread, "call", metamethod);
+                if (metamethod.TryGetMetamethod(thread.State, Metamethods.Call, out metamethod) &&
+                    metamethod.TryReadFunction(out func))
+                {
+                    stack.Push(callable);
+                    argCount++;
+                }
+                else
+                {
+                    LuaRuntimeException.AttemptInvalidOperation(thread, "call", metamethod);
+                }
             }
 
-            var stack = thread.Stack;
-            var top = stack.Count;
             stack.Push(vb);
             stack.Push(vc);
-            var varArgCount = func.GetVariableArgumentCount(2);
+            var varArgCount = func.GetVariableArgumentCount(argCount);
 
-            var newFrame = new CallStackFrame() { Base = thread.Stack.Count - 2 + varArgCount, VariableArgumentCount = varArgCount, Function = func, ReturnBase = top };
+            var newFrame = new CallStackFrame() { Base = thread.Stack.Count - argCount + varArgCount, VariableArgumentCount = varArgCount, Function = func, ReturnBase = top };
 
             thread.PushCallStackFrame(newFrame);
-            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = 2, ReturnFrameBase = top };
+            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = top };
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             {
                 await ExecuteCallHook(functionContext, ct);

+ 35 - 0
tests/Lua.Tests/LuaApiTests.cs

@@ -202,4 +202,39 @@ return a,b,c
             Assert.That(table[9].Read<double>(), Is.EqualTo(9));
         });
     }
+
+    [Test]
+    public async Task Test_Metamethod_MetaCallViaMeta()
+    {
+        var source = """
+                     local a = {name ="a"}
+                     setmetatable(a, {
+                         __call = function(a, b, c)
+                             return a.name..b.name..c.name
+                         end
+                     })
+
+
+                     local b = setmetatable({name="b"},
+                       {__unm = a,
+                       __add= a,
+                       __concat =a
+                       
+                       })
+                     local c ={name ="c"}
+                     return b,c
+                     """;
+        var result = await state.DoStringAsync(source);
+        var b = result[0];
+        var c = result[1];
+        var d = await state.MainThread.OpArithmetic(b, c, OpCode.Add);
+        Assert.True(d.TryRead(out string s));
+        Assert.That(s, Is.EqualTo("abc"));
+        d = await state.MainThread.OpUnary(b, OpCode.Unm);
+        Assert.True(d.TryRead(out s));
+        Assert.That(s, Is.EqualTo("abb"));
+        d = await state.MainThread.OpConcat([c, b]);
+        Assert.True(d.TryRead(out s));
+        Assert.That(s, Is.EqualTo("acb"));
+    }
 }

+ 26 - 0
tests/Lua.Tests/MetatableTests.cs

@@ -188,4 +188,30 @@ end
         Assert.That(r, Has.Length.EqualTo(1));
         Assert.That(r[0].Read<LuaTable>()[1].Read<string>(), Does.Contain("stack traceback:"));
     }
+
+    [Test]
+    public async Task Test_Metamethod_MetaCallViaMeta()
+    {
+        var source = """
+                     local a = {name ="a"}
+                     setmetatable(a, {
+                         __call = function(a, b, c)
+                             return a.name..b.name..c.name
+                         end
+                     })
+
+
+                     local b = setmetatable({name="b"},
+                       {__unm = a,
+                       __add= a,
+                       __concat =a
+                       
+                       })
+                     local c ={name ="c"}
+                     assert((b + c)== "abc")
+                     assert((b .. c)== "abc")
+                     """;
+        await state.DoStringAsync(source);
+    }
+
 }