Browse Source

feat: implement OpCall method for LuaThread and fix vararg handling

Akeit0 7 months ago
parent
commit
9a259854a9
3 changed files with 160 additions and 41 deletions
  1. 26 5
      src/Lua/LuaThreadExtensions.cs
  2. 85 33
      src/Lua/Runtime/LuaVirtualMachine.cs
  3. 49 3
      tests/Lua.Tests/LuaApiTests.cs

+ 26 - 5
src/Lua/LuaThreadExtensions.cs

@@ -264,15 +264,36 @@ public static class LuaThreadExtensions
 
 
         return LuaVirtualMachine.ExecuteSetTableSlowPath(thread, table, key, value, ct);
         return LuaVirtualMachine.ExecuteSetTableSlowPath(thread, table, key, value, ct);
     }
     }
-    
-    public static ValueTask<LuaValue> OpConcat(this LuaThread thread, ReadOnlySpan<LuaValue> values,CancellationToken ct = default)
+
+    public static ValueTask<LuaValue> OpConcat(this LuaThread thread, ReadOnlySpan<LuaValue> values, CancellationToken ct = default)
     {
     {
         thread.Stack.PushRange(values);
         thread.Stack.PushRange(values);
         return OpConcat(thread, values.Length, ct);
         return OpConcat(thread, values.Length, ct);
     }
     }
-    public static ValueTask<LuaValue> OpConcat(this LuaThread thread, int concatCount,CancellationToken ct = default)
+
+    public static ValueTask<LuaValue> OpConcat(this LuaThread thread, int concatCount, CancellationToken ct = default)
+    {
+        return LuaVirtualMachine.Concat(thread, concatCount, ct);
+    }
+
+    public static ValueTask<int> OpCall(this LuaThread thread, int funcIndex, CancellationToken ct = default)
     {
     {
-        
-        return LuaVirtualMachine.Concat(thread,  concatCount, ct);
+        return LuaVirtualMachine.Call(thread, funcIndex, ct);
+    }
+
+    public static ValueTask<LuaValue[]> OpCall(this LuaThread thread, LuaValue function, ReadOnlySpan<LuaValue> args, CancellationToken ct = default)
+    {
+        var funcIndex = thread.Stack.Count;
+        thread.Stack.Push(function);
+        thread.Stack.PushRange(args);
+        return Impl(thread, funcIndex, ct);
+
+        static async ValueTask<LuaValue[]> Impl(LuaThread thread, int funcIndex, CancellationToken ct)
+        {
+            await LuaVirtualMachine.Call(thread, funcIndex, ct);
+            var count = thread.Stack.Count - funcIndex;
+            using var results = thread.ReadReturnValues(count);
+            return results.AsSpan().ToArray();
+        }
     }
     }
 }
 }

+ 85 - 33
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -1131,6 +1131,41 @@ public static partial class LuaVirtualMachine
         }
         }
     }
     }
 
 
+    internal static async ValueTask<int> Call(LuaThread thread, int funcIndex, CancellationToken ct)
+    {
+        var stack = thread.Stack;
+        var newBase = funcIndex + 1;
+        var va = stack.Get(funcIndex);
+        if (!va.TryReadFunction(out var func))
+        {
+            if (va.TryGetMetamethod(thread.State, Metamethods.Call, out va) &&
+                va.TryReadFunction(out func))
+            {
+                newBase--;
+            }
+            else
+            {
+                LuaRuntimeException.AttemptInvalidOperation(thread, "call", va);
+            }
+        }
+
+        var (argCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, newBase);
+        newBase += variableArgumentCount;
+        var newFrame = new CallStackFrame() { Base = newBase, VariableArgumentCount = variableArgumentCount, Function = func, ReturnBase = funcIndex };
+
+        thread.PushCallStackFrame(newFrame);
+        var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = funcIndex };
+        if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
+        {
+            await ExecuteCallHook(functionContext, ct);
+        }
+
+
+        await func.Func(functionContext, ct);
+        thread.PopCallStackFrame();
+        return thread.Stack.Count - funcIndex;
+    }
+
     static void CallPostOperation(VirtualMachineExecutionContext context)
     static void CallPostOperation(VirtualMachineExecutionContext context)
     {
     {
         var instruction = context.Instruction;
         var instruction = context.Instruction;
@@ -1494,12 +1529,12 @@ public static partial class LuaVirtualMachine
         var top = stack.Count;
         var top = stack.Count;
         stack.Push(table);
         stack.Push(table);
         stack.Push(key);
         stack.Push(key);
-        var varArgCount = indexTable.GetVariableArgumentCount(3);
+        var varArgCount = indexTable.GetVariableArgumentCount(2);
 
 
         var newFrame = new CallStackFrame() { Base = thread.Stack.Count - 2 + varArgCount, VariableArgumentCount = varArgCount, Function = indexTable, ReturnBase = top };
         var newFrame = new CallStackFrame() { Base = thread.Stack.Count - 2 + varArgCount, VariableArgumentCount = varArgCount, Function = indexTable, ReturnBase = top };
 
 
         thread.PushCallStackFrame(newFrame);
         thread.PushCallStackFrame(newFrame);
-        var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = 3, ReturnFrameBase = top };
+        var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = 2, ReturnFrameBase = top };
         if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
         if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
         {
         {
             await ExecuteCallHook(functionContext, ct);
             await ExecuteCallHook(functionContext, ct);
@@ -1694,7 +1729,7 @@ public static partial class LuaVirtualMachine
             vc.TryGetMetamethod(context.State, name, out metamethod))
             vc.TryGetMetamethod(context.State, name, out metamethod))
         {
         {
             var stack = context.Stack;
             var stack = context.Stack;
-            var argCount = 2;
+            var newBase = stack.Count;
             var callable = metamethod;
             var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             if (!metamethod.TryReadFunction(out var func))
             {
             {
@@ -1702,7 +1737,6 @@ public static partial class LuaVirtualMachine
                     metamethod.TryReadFunction(out func))
                     metamethod.TryReadFunction(out func))
                 {
                 {
                     stack.Push(callable);
                     stack.Push(callable);
-                    argCount++;
                 }
                 }
                 else
                 else
                 {
                 {
@@ -1712,9 +1746,9 @@ public static partial class LuaVirtualMachine
 
 
             stack.Push(vb);
             stack.Push(vb);
             stack.Push(vc);
             stack.Push(vc);
-            var varArgCount = func.GetVariableArgumentCount(argCount);
-
-            var newFrame = func.CreateNewFrame(context, stack.Count - argCount + varArgCount, context.FrameBase + context.Instruction.A, varArgCount);
+            var (argCount, variableArgumentCount) = PrepareForFunctionCall(context.Thread, func, newBase);
+            newBase += variableArgumentCount;
+            var newFrame = func.CreateNewFrame(context, newBase, context.FrameBase + context.Instruction.A, variableArgumentCount);
 
 
             context.Thread.PushCallStackFrame(newFrame);
             context.Thread.PushCallStackFrame(newFrame);
             if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
             if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
@@ -1766,8 +1800,7 @@ public static partial class LuaVirtualMachine
             vc.TryGetMetamethod(thread.State, name, out metamethod))
             vc.TryGetMetamethod(thread.State, name, out metamethod))
         {
         {
             var stack = thread.Stack;
             var stack = thread.Stack;
-            var top = stack.Count;
-            var argCount = 2;
+            var newBase = stack.Count;
             var callable = metamethod;
             var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             if (!metamethod.TryReadFunction(out var func))
             {
             {
@@ -1775,7 +1808,6 @@ public static partial class LuaVirtualMachine
                     metamethod.TryReadFunction(out func))
                     metamethod.TryReadFunction(out func))
                 {
                 {
                     stack.Push(callable);
                     stack.Push(callable);
-                    argCount++;
                 }
                 }
                 else
                 else
                 {
                 {
@@ -1785,12 +1817,13 @@ public static partial class LuaVirtualMachine
 
 
             stack.Push(vb);
             stack.Push(vb);
             stack.Push(vc);
             stack.Push(vc);
-            var varArgCount = func.GetVariableArgumentCount(argCount);
+            var (argCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, newBase);
+            newBase += variableArgumentCount;
 
 
-            var newFrame = new CallStackFrame() { Base = thread.Stack.Count - argCount + varArgCount, VariableArgumentCount = varArgCount, Function = func, ReturnBase = top };
+            var newFrame = new CallStackFrame() { Base = newBase, VariableArgumentCount = variableArgumentCount, Function = func, ReturnBase = newBase };
 
 
             thread.PushCallStackFrame(newFrame);
             thread.PushCallStackFrame(newFrame);
-            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = top };
+            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = newBase };
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             {
             {
                 await ExecuteCallHook(functionContext, ct);
                 await ExecuteCallHook(functionContext, ct);
@@ -1818,7 +1851,7 @@ public static partial class LuaVirtualMachine
         var stack = context.Stack;
         var stack = context.Stack;
         if (vb.TryGetMetamethod(context.State, name, out var metamethod))
         if (vb.TryGetMetamethod(context.State, name, out var metamethod))
         {
         {
-            var argCount = 2;
+            var newBase = stack.Count;
             var callable = metamethod;
             var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             if (!metamethod.TryReadFunction(out var func))
             {
             {
@@ -1826,7 +1859,6 @@ public static partial class LuaVirtualMachine
                     metamethod.TryReadFunction(out func))
                     metamethod.TryReadFunction(out func))
                 {
                 {
                     stack.Push(callable);
                     stack.Push(callable);
-                    argCount++;
                 }
                 }
                 else
                 else
                 {
                 {
@@ -1837,15 +1869,16 @@ public static partial class LuaVirtualMachine
 
 
             stack.Push(vb);
             stack.Push(vb);
             stack.Push(vb);
             stack.Push(vb);
-            var varArgCount = func.GetVariableArgumentCount(argCount);
+            var (argCount, variableArgumentCount) = PrepareForFunctionCall(context.Thread, func, newBase);
+            newBase += variableArgumentCount;
 
 
-            var newFrame = func.CreateNewFrame(context, stack.Count - argCount + varArgCount, context.FrameBase + context.Instruction.A, varArgCount);
+            var newFrame = func.CreateNewFrame(context, newBase, context.FrameBase + context.Instruction.A, variableArgumentCount);
 
 
             context.Thread.PushCallStackFrame(newFrame);
             context.Thread.PushCallStackFrame(newFrame);
             if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
             if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
             {
             {
                 context.PostOperation = PostOperationType.SetResult;
                 context.PostOperation = PostOperationType.SetResult;
-                context.Task = ExecuteCallHook(context, newFrame, 1);
+                context.Task = ExecuteCallHook(context, newFrame, argCount);
                 doRestart = false;
                 doRestart = false;
                 return false;
                 return false;
             }
             }
@@ -1858,7 +1891,7 @@ public static partial class LuaVirtualMachine
             }
             }
 
 
 
 
-            var task = func.Invoke(context, newFrame, 1);
+            var task = func.Invoke(context, newFrame, argCount);
 
 
             if (!task.IsCompleted)
             if (!task.IsCompleted)
             {
             {
@@ -1894,8 +1927,7 @@ public static partial class LuaVirtualMachine
         if (vb.TryGetMetamethod(thread.State, name, out var metamethod))
         if (vb.TryGetMetamethod(thread.State, name, out var metamethod))
         {
         {
             var stack = thread.Stack;
             var stack = thread.Stack;
-            var top = stack.Count;
-            var argCount = 2;
+            var newBase = stack.Count;
             var callable = metamethod;
             var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             if (!metamethod.TryReadFunction(out var func))
             {
             {
@@ -1903,7 +1935,6 @@ public static partial class LuaVirtualMachine
                     metamethod.TryReadFunction(out func))
                     metamethod.TryReadFunction(out func))
                 {
                 {
                     stack.Push(callable);
                     stack.Push(callable);
-                    argCount++;
                 }
                 }
                 else
                 else
                 {
                 {
@@ -1913,12 +1944,12 @@ public static partial class LuaVirtualMachine
 
 
             stack.Push(vb);
             stack.Push(vb);
             stack.Push(vb);
             stack.Push(vb);
-            var varArgCount = func.GetVariableArgumentCount(argCount);
-
-            var newFrame = new CallStackFrame() { Base = thread.Stack.Count - argCount + varArgCount, VariableArgumentCount = varArgCount, Function = func, ReturnBase = top };
+            var (argCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, newBase);
+            newBase += variableArgumentCount;
+            var newFrame = new CallStackFrame() { Base = newBase, VariableArgumentCount = variableArgumentCount, Function = func, ReturnBase = newBase };
 
 
             thread.PushCallStackFrame(newFrame);
             thread.PushCallStackFrame(newFrame);
-            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = top };
+            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = newBase };
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             {
             {
                 await ExecuteCallHook(functionContext, ct);
                 await ExecuteCallHook(functionContext, ct);
@@ -2047,8 +2078,7 @@ public static partial class LuaVirtualMachine
             vc.TryGetMetamethod(thread.State, name, out metamethod))
             vc.TryGetMetamethod(thread.State, name, out metamethod))
         {
         {
             var stack = thread.Stack;
             var stack = thread.Stack;
-            var top = stack.Count;
-            var argCount = 2;
+            var newBase = stack.Count;
             var callable = metamethod;
             var callable = metamethod;
             if (!metamethod.TryReadFunction(out var func))
             if (!metamethod.TryReadFunction(out var func))
             {
             {
@@ -2056,7 +2086,6 @@ public static partial class LuaVirtualMachine
                     metamethod.TryReadFunction(out func))
                     metamethod.TryReadFunction(out func))
                 {
                 {
                     stack.Push(callable);
                     stack.Push(callable);
-                    argCount++;
                 }
                 }
                 else
                 else
                 {
                 {
@@ -2066,12 +2095,12 @@ public static partial class LuaVirtualMachine
 
 
             stack.Push(vb);
             stack.Push(vb);
             stack.Push(vc);
             stack.Push(vc);
-            var varArgCount = func.GetVariableArgumentCount(argCount);
-
-            var newFrame = new CallStackFrame() { Base = thread.Stack.Count - argCount + varArgCount, VariableArgumentCount = varArgCount, Function = func, ReturnBase = top };
+            var (argCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, newBase);
+            newBase += variableArgumentCount;
+            var newFrame = new CallStackFrame() { Base = newBase, VariableArgumentCount = variableArgumentCount, Function = func, ReturnBase = newBase };
 
 
             thread.PushCallStackFrame(newFrame);
             thread.PushCallStackFrame(newFrame);
-            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = top };
+            var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = newBase };
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             {
             {
                 await ExecuteCallHook(functionContext, ct);
                 await ExecuteCallHook(functionContext, ct);
@@ -2180,6 +2209,29 @@ public static partial class LuaVirtualMachine
         return PrepareVariableArgument(thread.Stack, newBase, argumentCount, variableArgumentCount);
         return PrepareVariableArgument(thread.Stack, newBase, argumentCount, variableArgumentCount);
     }
     }
 
 
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    static (int ArgumentCount, int VariableArgumentCount) PrepareForFunctionCall(LuaThread thread, LuaFunction function,
+        int newBase)
+    {
+        var argumentCount = (int)(thread.Stack.Count - newBase);
+
+        var variableArgumentCount = function.GetVariableArgumentCount(argumentCount);
+
+        if (variableArgumentCount < 0)
+        {
+            thread.Stack.SetTop(thread.Stack.Count - variableArgumentCount);
+            argumentCount -= variableArgumentCount;
+            variableArgumentCount = 0;
+        }
+
+        if (variableArgumentCount == 0)
+        {
+            return (argumentCount, 0);
+        }
+
+        return PrepareVariableArgument(thread.Stack, newBase, argumentCount, variableArgumentCount);
+    }
+
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     static (int ArgumentCount, int VariableArgumentCount) PrepareForFunctionTailCall(LuaThread thread, LuaFunction function,
     static (int ArgumentCount, int VariableArgumentCount) PrepareForFunctionTailCall(LuaThread thread, LuaFunction function,
         Instruction instruction, int newBase, bool isMetamethod)
         Instruction instruction, int newBase, bool isMetamethod)

+ 49 - 3
tests/Lua.Tests/LuaApiTests.cs

@@ -222,11 +222,53 @@ return a,b,c
                        
                        
                        })
                        })
                      local c ={name ="c"}
                      local c ={name ="c"}
-                     return b,c
+                     return a,b,c
                      """;
                      """;
         var result = await state.DoStringAsync(source);
         var result = await state.DoStringAsync(source);
-        var b = result[0];
-        var c = result[1];
+        var a = result[0];
+        var b = result[1];
+        var c = result[2];
+        var d = await state.MainThread.OpArithmetic(b, c, OpCode.Add);
+        Assert.True(d.TryRead(out string s));
+        Assert.That(s, Is.EqualTo("abc"));
+        d = await state.MainThread.OpUnary(b, OpCode.Unm);
+        Assert.True(d.TryRead(out s));
+        Assert.That(s, Is.EqualTo("abb"));
+        d = await state.MainThread.OpConcat([c, b]);
+        Assert.True(d.TryRead(out s));
+        Assert.That(s, Is.EqualTo("acb"));
+
+        var aResult = await state.MainThread.OpCall(a, [b, c]);
+        Assert.That(aResult, Has.Length.EqualTo(1));
+        Assert.That(aResult[0].Read<string>(), Is.EqualTo("abc"));
+    }
+    [Test]
+    public async Task Test_Metamethod_MetaCallViaMeta_VarArg()
+    {
+        var source = """
+                     local a = {name ="a"}
+                     setmetatable(a, {
+                         __call = function(a, ...)
+                            local args = {...}
+                            local b,c =args[1],args[2]
+                            return a.name..b.name..c.name
+                         end
+                     })
+
+
+                     local b = setmetatable({name="b"},
+                       {__unm = a,
+                       __add= a,
+                       __concat =a
+                       
+                       })
+                     local c ={name ="c"}
+                     return a,b,c
+                     """;
+        var result = await state.DoStringAsync(source);
+        var a = result[0];
+        var b = result[1];
+        var c = result[2];
         var d = await state.MainThread.OpArithmetic(b, c, OpCode.Add);
         var d = await state.MainThread.OpArithmetic(b, c, OpCode.Add);
         Assert.True(d.TryRead(out string s));
         Assert.True(d.TryRead(out string s));
         Assert.That(s, Is.EqualTo("abc"));
         Assert.That(s, Is.EqualTo("abc"));
@@ -236,5 +278,9 @@ return a,b,c
         d = await state.MainThread.OpConcat([c, b]);
         d = await state.MainThread.OpConcat([c, b]);
         Assert.True(d.TryRead(out s));
         Assert.True(d.TryRead(out s));
         Assert.That(s, Is.EqualTo("acb"));
         Assert.That(s, Is.EqualTo("acb"));
+
+        var aResult = await state.MainThread.OpCall(a, [b, c]);
+        Assert.That(aResult, Has.Length.EqualTo(1));
+        Assert.That(aResult[0].Read<string>(), Is.EqualTo("abc"));
     }
     }
 }
 }