Browse Source

Change: enhance traceback handling for TailCall

Akeit0 9 months ago
parent
commit
f9655ad9b7
3 changed files with 130 additions and 81 deletions
  1. 5 2
      src/Lua/Runtime/CallStackFrame.cs
  2. 69 63
      src/Lua/Runtime/LuaVirtualMachine.cs
  3. 56 16
      src/Lua/Runtime/Tracebacks.cs

+ 5 - 2
src/Lua/Runtime/CallStackFrame.cs

@@ -10,11 +10,14 @@ public record struct CallStackFrame
     public required int VariableArgumentCount;
     public int CallerInstructionIndex;
     internal CallStackFrameFlags Flags;
+    internal bool IsTailCall => (Flags & CallStackFrameFlags.TailCall) ==CallStackFrameFlags.TailCall;
 }
 
 [Flags]
 public enum CallStackFrameFlags
 {
-    ReversedLe = 1,
-    TailCall
+    //None = 0,
+    ReversedLe  = 1,
+    TailCall = 2,
+    InHook = 4,
 }

+ 69 - 63
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -39,6 +39,8 @@ public static partial class LuaVirtualMachine
 
         readonly int BaseCallStackCount = thread.CallStack.Count;
 
+        public PostOperationType PostOperation;
+
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         public bool Pop(Instruction instruction, int frameBase)
         {
@@ -186,49 +188,60 @@ public static partial class LuaVirtualMachine
             ResultsBuffer.AsSpan(0, count).Clear();
         }
 
+        public int? ExecutePostOperation(PostOperationType postOperation)
+        {
+            switch (postOperation)
+            {
+                case PostOperationType.Nop: break;
+                case PostOperationType.SetResult:
+                    var RA = Instruction.A + FrameBase;
+                    Stack.Get(RA) = TaskResult == 0 ? LuaValue.Nil : ResultsBuffer[0];
+                    Stack.NotifyTop(RA + 1);
+                    ClearResultsBuffer();
+                    break;
+                case PostOperationType.TForCall:
+                    TForCallPostOperation(ref this);
+                    break;
+                case PostOperationType.Call:
+                    CallPostOperation(ref this);
+                    break;
+                case PostOperationType.TailCall:
+                    var resultsSpan = ResultsBuffer.AsSpan(0, TaskResult);
+                    if (!PopFromBuffer(resultsSpan))
+                    {
+                        ResultCount = TaskResult;
+                        resultsSpan.CopyTo(Buffer.Span);
+                        resultsSpan.Clear();
+                        LuaValueArrayPool.Return1024(ResultsBuffer);
+                        return TaskResult;
+                    }
+
+                    resultsSpan.Clear();
+                    break;
+                case PostOperationType.Self:
+                    SelfPostOperation(ref this);
+                    break;
+                case PostOperationType.Compare:
+                    ComparePostOperation(ref this);
+                    break;
+            }
+
+            return null;
+        }
+
         public async ValueTask<int> ExecuteClosureAsyncImpl()
         {
-            while (MoveNext(ref this, out var postOperation))
+            while (MoveNext(ref this))
             {
                 TaskResult = await Task;
                 Task = default;
-
-                Thread.PopCallStackFrame();
-                switch (postOperation)
+                if (PostOperation != PostOperationType.TailCall)
                 {
-                    case PostOperationType.Nop: break;
-                    case PostOperationType.SetResult:
-                        var RA = Instruction.A + FrameBase;
-                        Stack.Get(RA) = TaskResult == 0 ? LuaValue.Nil : ResultsBuffer[0];
-                        Stack.NotifyTop(RA + 1);
-                        ClearResultsBuffer();
-                        break;
-                    case PostOperationType.TForCall:
-                        TForCallPostOperation(ref this);
-                        break;
-                    case PostOperationType.Call:
-                        CallPostOperation(ref this);
-                        break;
-                    case PostOperationType.TailCall:
-                        var resultsSpan = ResultsBuffer.AsSpan(0, TaskResult);
-                        if (!PopFromBuffer(resultsSpan))
-                        {
-                            ResultCount = TaskResult;
-                            resultsSpan.CopyTo(Buffer.Span);
-                            resultsSpan.Clear();
-                            LuaValueArrayPool.Return1024(ResultsBuffer);
-                            return TaskResult;
-                        }
-
-                        resultsSpan.Clear();
-                        break;
-                    case PostOperationType.Self:
-                        SelfPostOperation(ref this);
-                        break;
-                    case PostOperationType.Compare:
-                        ComparePostOperation(ref this);
-                        break;
+                    Thread.PopCallStackFrame();
                 }
+
+                var r = ExecutePostOperation(PostOperation);
+                if (r.HasValue) return r.Value;
             }
 
             return ResultCount;
@@ -259,20 +272,17 @@ public static partial class LuaVirtualMachine
         return context.ExecuteClosureAsyncImpl();
     }
 
-    static bool MoveNext(ref VirtualMachineExecutionContext context, out PostOperationType postOperation)
+    static bool MoveNext(ref VirtualMachineExecutionContext context)
     {
-        postOperation = PostOperationType.None;
-
         try
         {
-        // This is a label to restart the execution when new function is called or restarted
+            // This is a label to restart the execution when new function is called or restarted
         Restart:
             ref var instructionsHead = ref context.Chunk.Instructions[0];
             var frameBase = context.FrameBase;
             var stack = context.Stack;
             stack.EnsureCapacity(frameBase + context.Chunk.MaxStackPosition);
             ref var constHead = ref MemoryMarshalEx.UnsafeElementAt(context.Chunk.Constants, 0);
-
             while (true)
             {
                 var instructionRef = Unsafe.Add(ref instructionsHead, ++context.Pc);
@@ -320,7 +330,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.SetResult;
                         return true;
                     case OpCode.SetTabUp:
                         instruction = instructionRef;
@@ -354,7 +363,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.Nop;
                         return true;
 
                     case OpCode.SetUpVal:
@@ -394,7 +402,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.Nop;
                         return true;
                     case OpCode.NewTable:
                         instruction = instructionRef;
@@ -417,7 +424,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.Self;
                         return true;
                     case OpCode.Add:
                         instruction = instructionRef;
@@ -445,7 +451,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.SetResult;
                         return true;
                     case OpCode.Sub:
                         instruction = instructionRef;
@@ -476,7 +481,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.SetResult;
                         return true;
 
                     case OpCode.Mul:
@@ -508,7 +512,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.SetResult;
                         return true;
 
                     case OpCode.Div:
@@ -540,7 +543,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.SetResult;
                         return true;
                     case OpCode.Mod:
                         instruction = instructionRef;
@@ -566,7 +568,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.SetResult;
                         return true;
                     case OpCode.Pow:
                         instruction = instructionRef;
@@ -588,7 +589,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.SetResult;
                         return true;
                     case OpCode.Unm:
                         instruction = instructionRef;
@@ -610,7 +610,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.SetResult;
                         return true;
                     case OpCode.Not:
                         instruction = instructionRef;
@@ -642,7 +641,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.SetResult;
                         return true;
                     case OpCode.Concat:
                         if (Concat(ref context, out doRestart))
@@ -651,7 +649,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.SetResult;
                         return true;
                     case OpCode.Jmp:
                         instruction = instructionRef;
@@ -685,7 +682,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.Compare;
                         return true;
                     case OpCode.Lt:
                         instruction = instructionRef;
@@ -723,7 +719,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.Compare;
                         return true;
                     case OpCode.Le:
                         instruction = instructionRef;
@@ -760,7 +755,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.Compare;
                         return true;
                     case OpCode.Test:
                         instruction = instructionRef;
@@ -791,7 +785,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.Call;
                         return true;
                     case OpCode.TailCall:
                         if (TailCall(ref context, out doRestart))
@@ -801,7 +794,6 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.TailCall;
                         return true;
                     case OpCode.Return:
                         instruction = instructionRef;
@@ -872,7 +864,7 @@ public static partial class LuaVirtualMachine
                             continue;
                         }
 
-                        postOperation = PostOperationType.TForCall;
+
                         return true;
                     case OpCode.TForLoop:
                         instruction = instructionRef;
@@ -926,7 +918,7 @@ public static partial class LuaVirtualMachine
             }
 
         End:
-            postOperation = PostOperationType.None;
+            context.PostOperation = PostOperationType.None;
             LuaValueArrayPool.Return1024(context.ResultsBuffer);
             return false;
         }
@@ -1032,6 +1024,7 @@ public static partial class LuaVirtualMachine
         var newFrame = func.CreateNewFrame(ref context, newBase, variableArgumentCount);
 
         thread.PushCallStackFrame(newFrame);
+        
         if (func is Closure)
         {
             context.Push(newFrame);
@@ -1048,6 +1041,7 @@ public static partial class LuaVirtualMachine
 
             if (!task.IsCompleted)
             {
+                context.PostOperation = PostOperationType.Call;
                 context.Task = task;
                 return false;
             }
@@ -1148,7 +1142,14 @@ public static partial class LuaVirtualMachine
 
         var (newBase, argumentCount, variableArgumentCount) = PrepareForFunctionTailCall(thread, func, instruction, RA);
 
+
+        var lastPc = thread.CallStack.AsSpan()[^1].CallerInstructionIndex;
+        context.Thread.PopCallStackFrameUnsafe();
+
         var newFrame = func.CreateNewFrame(ref context, newBase, variableArgumentCount);
+
+        newFrame.Flags |= CallStackFrameFlags.TailCall;
+        newFrame.CallerInstructionIndex = lastPc;
         thread.PushCallStackFrame(newFrame);
 
         context.Push(newFrame);
@@ -1163,12 +1164,11 @@ public static partial class LuaVirtualMachine
 
         if (!task.IsCompleted)
         {
+            context.PostOperation = PostOperationType.TailCall;
             context.Task = task;
             return false;
         }
 
-        context.Thread.PopCallStackFrame();
-
         doRestart = true;
         var awaiter = task.GetAwaiter();
         var resultCount = awaiter.GetResult();
@@ -1214,6 +1214,7 @@ public static partial class LuaVirtualMachine
 
         if (!task.IsCompleted)
         {
+            context.PostOperation = PostOperationType.TForCall;
             context.Task = task;
 
             return false;
@@ -1264,6 +1265,7 @@ public static partial class LuaVirtualMachine
         table.EnsureArrayCapacity((instruction.C - 1) * 50 + count);
         stack.GetBuffer().Slice(RA + 1, count)
             .CopyTo(table.GetArraySpan()[((instruction.C - 1) * 50)..]);
+        stack.PopUntil(RA + 1);
     }
 
     static void ComparePostOperation(ref VirtualMachineExecutionContext context)
@@ -1358,6 +1360,7 @@ public static partial class LuaVirtualMachine
 
         if (!task.IsCompleted)
         {
+            context.PostOperation = context.Instruction.OpCode == OpCode.GetTable ? PostOperationType.SetResult : PostOperationType.Self;
             context.Task = task;
             result = default;
             return false;
@@ -1417,6 +1420,7 @@ public static partial class LuaVirtualMachine
         Function:
             if (table.TryReadFunction(out var function))
             {
+                context.PostOperation = PostOperationType.Nop;
                 return CallSetTableFunc(targetTable, function, key, value, ref context, out doRestart);
             }
         }
@@ -1495,6 +1499,7 @@ public static partial class LuaVirtualMachine
 
             if (!task.IsCompleted)
             {
+                context.PostOperation = PostOperationType.SetResult;
                 context.Task = task;
                 return false;
             }
@@ -1597,6 +1602,7 @@ public static partial class LuaVirtualMachine
 
             if (!task.IsCompleted)
             {
+                context.PostOperation = PostOperationType.Compare;
                 context.Task = task;
                 return false;
             }

+ 56 - 16
src/Lua/Runtime/Tracebacks.cs

@@ -22,24 +22,30 @@ public class Traceback(LuaState state)
             {
                 LuaFunction lastFunc = index > 0 ? stackFrames[index - 1].Function : RootFunc;
                 var frame = stackFrames[index];
-                if (lastFunc is Closure closure)
+                if (!frame.IsTailCall && lastFunc is Closure closure)
                 {
                     var p = closure.Proto;
+                    if (frame.CallerInstructionIndex < 0 || p.SourcePositions.Length <= frame.CallerInstructionIndex)
+                    {
+                        Console.WriteLine($"Trace back error");
+                        return default;
+                    }
+
                     return p.SourcePositions[frame.CallerInstructionIndex];
                 }
             }
 
+
             return default;
         }
     }
 
-
     public override string ToString()
     {
-        return GetTracebackString(State,RootFunc, StackFrames, LuaValue.Nil);
+        return GetTracebackString(State, RootFunc, StackFrames, LuaValue.Nil);
     }
 
-    internal static string GetTracebackString(LuaState state,Closure rootFunc, ReadOnlySpan<CallStackFrame> stackFrames, LuaValue message)
+    internal static string GetTracebackString(LuaState state, Closure rootFunc, ReadOnlySpan<CallStackFrame> stackFrames, LuaValue message, bool skipFirstCsharpCall = false)
     {
         using var list = new PooledList<char>(64);
         if (message.Type is not LuaValueType.Nil)
@@ -51,6 +57,15 @@ public class Traceback(LuaState state)
         list.AddRange("stack traceback:\n");
         var intFormatBuffer = (stackalloc char[15]);
         var shortSourceBuffer = (stackalloc char[59]);
+        {
+            if (0 < stackFrames.Length && !skipFirstCsharpCall && stackFrames[^1].Function is { } f and not Closure)
+            {
+                list.AddRange("\t[C#]: in function '");
+                list.AddRange(f.Name);
+                list.AddRange("'\n");
+            }
+        }
+
         for (var index = stackFrames.Length - 1; index >= 0; index--)
         {
             LuaFunction lastFunc = index > 0 ? stackFrames[index - 1].Function : rootFunc;
@@ -63,27 +78,50 @@ public class Traceback(LuaState state)
             else if (lastFunc is Closure closure)
             {
                 var frame = stackFrames[index];
+
+                if (frame.IsTailCall)
+                {
+                    list.AddRange("\t(...tail calls...)\n");
+                }
+
                 var p = closure.Proto;
                 var root = p.GetRoot();
                 list.AddRange("\t");
                 var len = LuaDebug.WriteShortSource(root.Name, shortSourceBuffer);
                 list.AddRange(shortSourceBuffer[..len]);
                 list.AddRange(":");
-                p.SourcePositions[frame.CallerInstructionIndex].Line.TryFormat(intFormatBuffer, out var charsWritten, provider: CultureInfo.InvariantCulture);
-                list.AddRange(intFormatBuffer[..charsWritten]);
+                if (p.SourcePositions.Length <= frame.CallerInstructionIndex)
+                {
+                    list.AddRange("Trace back error");
+                }
+                else
+                {
+                    p.SourcePositions[frame.CallerInstructionIndex].Line.TryFormat(intFormatBuffer, out var charsWritten, provider: CultureInfo.InvariantCulture);
+                    list.AddRange(intFormatBuffer[..charsWritten]);
+                }
+
 
                 list.AddRange(": in ");
                 if (root == p)
                 {
                     list.AddRange("main chunk");
                     list.AddRange("\n");
-                    continue;
+                    goto Next;
+                }
+
+                if (0 < index && stackFrames[index - 1].Flags.HasFlag(CallStackFrameFlags.InHook))
+                {
+                    list.AddRange("hook");
+                    list.AddRange(" '");
+                    list.AddRange("?");
+                    list.AddRange("'\n");
+                    goto Next;
                 }
 
                 foreach (var pair in state.Environment.Dictionary)
                 {
-                    if (pair.Key.TryReadString(out var name) 
-                        && pair.Value.TryReadFunction(out var result) && 
+                    if (pair.Key.TryReadString(out var name)
+                        && pair.Value.TryReadFunction(out var result) &&
                         result == closure)
                     {
                         list.AddRange("function '");
@@ -113,19 +151,21 @@ public class Traceback(LuaState state)
                             list.AddRange("'\n");
                         }
 
-                        continue;
+                        goto Next;
                     }
                 }
 
-                
+
                 list.AddRange("function <");
                 list.AddRange(shortSourceBuffer[..len]);
                 list.AddRange(":");
-                p.LineDefined.TryFormat(intFormatBuffer, out charsWritten, provider: CultureInfo.InvariantCulture);
-                list.AddRange(intFormatBuffer[..charsWritten]);
-                list.AddRange(">\n");
-                
-                Next: ;
+                {
+                    p.LineDefined.TryFormat(intFormatBuffer, out var charsWritten, provider: CultureInfo.InvariantCulture);
+                    list.AddRange(intFormatBuffer[..charsWritten]);
+                    list.AddRange(">\n");
+                }
+
+            Next: ;
             }
         }