Browse Source

Fix: hook metamethods

Akeit0 9 months ago
parent
commit
3c2be936ff

+ 7 - 0
src/Lua/LuaFunction.cs

@@ -22,8 +22,15 @@ public class LuaFunction(string name, Func<LuaFunctionExecutionContext, Memory<L
         };
 
         context.Thread.PushCallStackFrame(frame);
+
+
         try
         {
+            if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
+            {
+                return await LuaVirtualMachine.ExecuteCallHook(context, buffer, cancellationToken);
+            }
+
             return await Func(context, buffer, cancellationToken);
         }
         finally

+ 87 - 79
src/Lua/Runtime/LuaVirtualMachine.Debug.cs

@@ -113,99 +113,107 @@ public static partial class LuaVirtualMachine
         }
     }
 
-    static void ExecuteCallHook(ref VirtualMachineExecutionContext context, int argCount, bool isTailCall = false)
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static ValueTask<int> ExecuteCallHook(ref VirtualMachineExecutionContext context, in CallStackFrame frame, int arguments, bool isTailCall = false)
     {
-        context.Task = Impl(context, argCount, isTailCall);
-
-        static async ValueTask<int> Impl(VirtualMachineExecutionContext context, int argCount, bool isTailCall)
+        return ExecuteCallHook(new()
         {
-            var topFrame = context.Thread.GetCurrentFrame();
-            var hook = context.Thread.Hook!;
-            var stack = context.Thread.Stack;
-            CallStackFrame frame;
-            if (context.Thread.IsCallHookEnabled)
-            {
-                stack.Push((isTailCall ? "tail call" : "call"));
+            State = context.State,
+            Thread = context.Thread,
+            ArgumentCount = arguments,
+            FrameBase = frame.Base,
+            CallerInstructionIndex = frame.CallerInstructionIndex,
+        }, context.ResultsBuffer, context.CancellationToken, isTailCall);
+    }
 
-                stack.Push(LuaValue.Nil);
-                var funcContext = new LuaFunctionExecutionContext
-                {
-                    State = context.State,
-                    Thread = context.Thread,
-                    ArgumentCount = 2,
-                    FrameBase = context.Thread.Stack.Count - 2,
-                };
-                frame = new()
-                {
-                    Base = funcContext.FrameBase,
-                    VariableArgumentCount = hook.GetVariableArgumentCount(2),
-                    Function = hook,
-                    CallerInstructionIndex = 0,
-                };
-                frame.Flags |= CallStackFrameFlags.InHook;
+    internal static async ValueTask<int> ExecuteCallHook(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken, bool isTailCall = false)
+    {
+        var topFrame = context.Thread.GetCurrentFrame();
+        var argCount = context.ArgumentCount;
+        var hook = context.Thread.Hook!;
+        var stack = context.Thread.Stack;
+        CallStackFrame frame;
+        if (context.Thread.IsCallHookEnabled)
+        {
+            stack.Push((isTailCall ? "tail call" : "call"));
 
-                context.Thread.PushCallStackFrame(frame);
-                try
-                {
-                    context.Thread.IsInHook = true;
-                    await hook.Func(funcContext, Memory<LuaValue>.Empty, context.CancellationToken);
-                }
-                finally
-                {
-                    context.Thread.IsInHook = false;
-                    context.Thread.PopCallStackFrame();
-                }
+            stack.Push(LuaValue.Nil);
+            var funcContext = new LuaFunctionExecutionContext
+            {
+                State = context.State,
+                Thread = context.Thread,
+                ArgumentCount = 2,
+                FrameBase = context.Thread.Stack.Count - 2,
+            };
+            frame = new()
+            {
+                Base = funcContext.FrameBase,
+                VariableArgumentCount = hook.GetVariableArgumentCount(2),
+                Function = hook,
+                CallerInstructionIndex = 0,
+            };
+            frame.Flags |= CallStackFrameFlags.InHook;
+
+            context.Thread.PushCallStackFrame(frame);
+            try
+            {
+                context.Thread.IsInHook = true;
+                await hook.Func(funcContext, Memory<LuaValue>.Empty, cancellationToken);
+            }
+            finally
+            {
+                context.Thread.IsInHook = false;
+                context.Thread.PopCallStackFrame();
             }
+        }
 
-            frame = topFrame;
-            context.Push(frame);
+        frame = topFrame;
 
-            var task = frame.Function.Func(new ()
+        var task = frame.Function.Func(new()
+        {
+            State = context.State,
+            Thread = context.Thread,
+            ArgumentCount = argCount,
+            FrameBase = frame.Base,
+        }, buffer, cancellationToken);
+        if (isTailCall || !context.Thread.IsReturnHookEnabled)
+        {
+            return await task;
+        }
+
+        {
+            var result = await task;
+            stack.Push("return");
+            stack.Push(LuaValue.Nil);
+            var funcContext = new LuaFunctionExecutionContext
             {
                 State = context.State,
                 Thread = context.Thread,
-                ArgumentCount = argCount,
-                FrameBase = frame.Base,
-            }, context.ResultsBuffer, context.CancellationToken);
-            if (isTailCall || !context.Thread.IsReturnHookEnabled)
+                ArgumentCount = 2,
+                FrameBase = context.Thread.Stack.Count - 2,
+            };
+            frame = new()
             {
-                return await task;
+                Base = funcContext.FrameBase,
+                VariableArgumentCount = hook.GetVariableArgumentCount(2),
+                Function = hook,
+                CallerInstructionIndex = 0
+            };
+            frame.Flags |= CallStackFrameFlags.InHook;
+
+            context.Thread.PushCallStackFrame(frame);
+            try
+            {
+                context.Thread.IsInHook = true;
+                await hook.Func(funcContext, Memory<LuaValue>.Empty, cancellationToken);
             }
-
+            finally
             {
-                var result = await task;
-                stack.Push("return");
-                stack.Push(LuaValue.Nil);
-                var funcContext = new LuaFunctionExecutionContext
-                {
-                    State = context.State,
-                    Thread = context.Thread,
-                    ArgumentCount = 2,
-                    FrameBase = context.Thread.Stack.Count - 2,
-                };
-                frame = new()
-                {
-                    Base = funcContext.FrameBase,
-                    VariableArgumentCount = hook.GetVariableArgumentCount(2),
-                    Function = hook,
-                    CallerInstructionIndex = 0
-                };
-                frame.Flags |= CallStackFrameFlags.InHook;
-
-                context.Thread.PushCallStackFrame(frame);
-                try
-                {
-                    context.Thread.IsInHook = true;
-                    await hook.Func(funcContext, Memory<LuaValue>.Empty, context.CancellationToken);
-                }
-                finally
-                {
-                    context.Thread.IsInHook = false;
-                }
-
-                context.Thread.PopCallStackFrame();
-                return result;
+                context.Thread.IsInHook = false;
             }
+
+            context.Thread.PopCallStackFrame();
+            return result;
         }
     }
 }

+ 43 - 6
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -308,7 +308,7 @@ public static partial class LuaVirtualMachine
             {
                 var instructionRef = Unsafe.Add(ref instructionsHead, ++context.Pc);
                 context.Instruction = instructionRef;
-                if (lineAndCountHookMask.Value!=0 && (context.Pc != context.LastHookPc))
+                if (lineAndCountHookMask.Value != 0 && (context.Pc != context.LastHookPc))
                 {
                     goto LineHook;
                 }
@@ -1052,10 +1052,10 @@ public static partial class LuaVirtualMachine
         var newFrame = func.CreateNewFrame(ref context, newBase, variableArgumentCount);
 
         thread.PushCallStackFrame(newFrame);
-        if (thread.CallOrReturnHookMask.Value!=0 && !context.Thread.IsInHook)
+        if (thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
         {
             context.PostOperation = PostOperationType.Call;
-            ExecuteCallHook(ref context, argumentCount);
+            context.Task=ExecuteCallHook(ref context, newFrame,argumentCount);
             doRestart = false;
             return false;
         }
@@ -1187,10 +1187,10 @@ public static partial class LuaVirtualMachine
         newFrame.CallerInstructionIndex = lastPc;
         thread.PushCallStackFrame(newFrame);
 
-        if (thread.CallOrReturnHookMask.Value!=0 && !context.Thread.IsInHook)
+        if (thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
         {
             context.PostOperation = PostOperationType.TailCall;
-            ExecuteCallHook(ref context, argumentCount, true);
+           context.Task=ExecuteCallHook(ref context, newFrame,argumentCount,true);
             doRestart = false;
             return false;
         }
@@ -1390,6 +1390,14 @@ public static partial class LuaVirtualMachine
         var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 2);
 
         context.Thread.PushCallStackFrame(newFrame);
+        if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
+        {
+            context.PostOperation = context.Instruction.OpCode == OpCode.GetTable ? PostOperationType.SetResult : PostOperationType.Self;
+           context.Task=ExecuteCallHook(ref context, newFrame,2);
+            doRestart = false;
+            result = default;
+            return false;
+        }
 
         if (indexTable is Closure)
         {
@@ -1483,6 +1491,13 @@ public static partial class LuaVirtualMachine
         var newFrame = newIndexFunction.CreateNewFrame(ref context, stack.Count - 3);
 
         context.Thread.PushCallStackFrame(newFrame);
+        if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
+        {
+            context.PostOperation = PostOperationType.Nop;
+             context.Task=ExecuteCallHook(ref context, newFrame,3);
+            doRestart = false;
+            return false;
+        }
 
         if (newIndexFunction is Closure)
         {
@@ -1494,6 +1509,7 @@ public static partial class LuaVirtualMachine
         var task = newIndexFunction.Invoke(ref context, newFrame, 3);
         if (!task.IsCompleted)
         {
+            context.PostOperation = PostOperationType.Nop;
             context.Task = task;
             return false;
         }
@@ -1529,6 +1545,13 @@ public static partial class LuaVirtualMachine
             var newFrame = func.CreateNewFrame(ref context, stack.Count - 2);
 
             context.Thread.PushCallStackFrame(newFrame);
+            if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
+            {
+                context.PostOperation = PostOperationType.SetResult;
+                context.Task=ExecuteCallHook(ref context, newFrame,2);
+                doRestart = false;
+                return false;
+            }
 
             if (func is Closure)
             {
@@ -1576,6 +1599,13 @@ public static partial class LuaVirtualMachine
             var newFrame = func.CreateNewFrame(ref context, stack.Count - 1);
 
             context.Thread.PushCallStackFrame(newFrame);
+            if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
+            {
+                context.PostOperation = PostOperationType.SetResult;
+              context.Task=ExecuteCallHook(ref context, newFrame,1);
+                doRestart = false;
+                return false;
+            }
 
             if (func is Closure)
             {
@@ -1589,6 +1619,7 @@ public static partial class LuaVirtualMachine
 
             if (!task.IsCompleted)
             {
+                context.PostOperation = PostOperationType.SetResult;
                 context.Task = task;
                 return false;
             }
@@ -1633,7 +1664,13 @@ public static partial class LuaVirtualMachine
             var newFrame = func.CreateNewFrame(ref context, stack.Count - 2);
             if (reverseLe) newFrame.Flags |= CallStackFrameFlags.ReversedLe;
             context.Thread.PushCallStackFrame(newFrame);
-
+            if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook)
+            {
+                context.PostOperation = PostOperationType.Compare;
+                context.Task=ExecuteCallHook(ref context, newFrame,2);
+                doRestart = false;
+                return false;
+            }
             if (func is Closure)
             {
                 context.Push(newFrame);

+ 17 - 1
tests/Lua.Tests/MetatableTests.cs

@@ -10,7 +10,7 @@ public class MetatableTests
     public void SetUp()
     {
         state = LuaState.Create();
-        state.OpenBasicLibrary();
+        state.OpenStandardLibraries();
     }
 
     [Test]
@@ -86,4 +86,20 @@ assert(metatable.__newindex.x == 2)
 ";
         await state.DoStringAsync(source);
     }
+
+    [Test]
+    public async Task Test_Hook_Metamethods()
+    {
+        var source = """ 
+                     local t = {}
+                     local a =setmetatable({},{__add =function (a,b) return a end})
+
+                     debug.sethook(function () table.insert(t,debug.traceback()) end,"c")
+                     a =a+a
+                     return t
+                     """;
+        var r = await state.DoStringAsync(source);
+        Assert.That(r, Has.Length.EqualTo(1));
+        Assert.That(r[0].Read<LuaTable>()[1].Read<string>(), Does.Contain("stack traceback:"));
+    }
 }