Bladeren bron

Merge pull request #147 from nuskey8/improve-cancel

Improved cancellation process and avoidance of infinite loops
Akeit0 7 maanden geleden
bovenliggende
commit
d36e9b8dcf

+ 58 - 6
src/Lua/Exceptions.cs

@@ -2,6 +2,7 @@ using Lua.CodeAnalysis;
 using Lua.CodeAnalysis.Syntax;
 using Lua.CodeAnalysis.Syntax;
 using Lua.Internal;
 using Lua.Internal;
 using Lua.Runtime;
 using Lua.Runtime;
+using System.Diagnostics;
 using System.Runtime.CompilerServices;
 using System.Runtime.CompilerServices;
 
 
 namespace Lua;
 namespace Lua;
@@ -67,9 +68,14 @@ public class LuaCompileException(string chunkName, SourcePosition position, int
 
 
 public class LuaUnDumpException(string message) : Exception(message);
 public class LuaUnDumpException(string message) : Exception(message);
 
 
-public class LuaRuntimeException : Exception
+internal interface ILuaTracebackBuildable
 {
 {
-    public LuaRuntimeException(LuaThread? thread, Exception innerException) : base(innerException.Message,innerException)
+    Traceback? BuildOrGet();
+}
+
+public class LuaRuntimeException : Exception, ILuaTracebackBuildable
+{
+    public LuaRuntimeException(LuaThread? thread, Exception innerException) : base(innerException.Message, innerException)
     {
     {
         Thread = thread;
         Thread = thread;
     }
     }
@@ -78,7 +84,7 @@ public class LuaRuntimeException : Exception
     {
     {
         if (thread != null)
         if (thread != null)
         {
         {
-            thread.CurrentException?.Build();
+            thread.CurrentException?.BuildOrGet();
             thread.ExceptionTrace.Clear();
             thread.ExceptionTrace.Clear();
             thread.CurrentException = this;
             thread.CurrentException = this;
         }
         }
@@ -96,7 +102,7 @@ public class LuaRuntimeException : Exception
         {
         {
             if (luaTraceback == null)
             if (luaTraceback == null)
             {
             {
-                Build();
+                ((ILuaTracebackBuildable)this).BuildOrGet();
             }
             }
 
 
             return luaTraceback;
             return luaTraceback;
@@ -185,7 +191,7 @@ public class LuaRuntimeException : Exception
 
 
 
 
     [MethodImpl(MethodImplOptions.NoInlining)]
     [MethodImpl(MethodImplOptions.NoInlining)]
-    internal Traceback? Build()
+    Traceback? ILuaTracebackBuildable.BuildOrGet()
     {
     {
         if (luaTraceback != null) return luaTraceback;
         if (luaTraceback != null) return luaTraceback;
         if (Thread != null)
         if (Thread != null)
@@ -247,4 +253,50 @@ public class LuaRuntimeException : Exception
 
 
 public class LuaAssertionException(LuaThread? traceback, string message) : LuaRuntimeException(traceback, message);
 public class LuaAssertionException(LuaThread? traceback, string message) : LuaRuntimeException(traceback, message);
 
 
-public class LuaModuleNotFoundException(string moduleName) : Exception($"module '{moduleName}' not found");
+public class LuaModuleNotFoundException(string moduleName) : Exception($"module '{moduleName}' not found");
+
+public sealed class LuaCancelledException : OperationCanceledException, ILuaTracebackBuildable
+{
+    Traceback? luaTraceback;
+
+    public Traceback? LuaTraceback
+    {
+        get
+        {
+            if (luaTraceback == null)
+            {
+                ((ILuaTracebackBuildable)this).BuildOrGet();
+            }
+
+            return luaTraceback;
+        }
+    }
+
+    internal LuaThread? Thread { get; private set; }
+
+    internal LuaCancelledException(LuaThread thread, CancellationToken cancellationToken, Exception? innerException = null) : base("The operation was cancelled during execution on Lua.", innerException, cancellationToken)
+    {
+        thread.CurrentException?.BuildOrGet();
+        thread.ExceptionTrace.Clear();
+        thread.CurrentException = this;
+        Thread = thread;
+    }
+
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    Traceback? ILuaTracebackBuildable.BuildOrGet()
+    {
+        if (luaTraceback != null) return luaTraceback;
+
+        if (Thread != null)
+        {
+            var callStack = Thread.ExceptionTrace.AsSpan();
+            if (callStack.IsEmpty) return null;
+            luaTraceback = new Traceback(Thread.State, callStack);
+            Thread.ExceptionTrace.Clear();
+            Thread = null!;
+        }
+
+        return luaTraceback;
+    }
+}

+ 2 - 2
src/Lua/LuaCoroutine.cs

@@ -186,9 +186,9 @@ public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.Yiel
             {
             {
                 if (IsProtectedMode)
                 if (IsProtectedMode)
                 {
                 {
-                    if (ex is LuaRuntimeException luaRuntimeException)
+                    if (ex is ILuaTracebackBuildable tracebackBuildable)
                     {
                     {
-                        traceback = luaRuntimeException.Build();
+                        traceback = tracebackBuildable.BuildOrGet();
                     }
                     }
 
 
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                     Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);

+ 5 - 2
src/Lua/LuaThread.cs

@@ -70,8 +70,11 @@ public abstract class LuaThread
     internal int LastVersion;
     internal int LastVersion;
     internal int CurrentVersion;
     internal int CurrentVersion;
 
 
-    internal LuaRuntimeException? CurrentException;
+    internal ILuaTracebackBuildable? CurrentException;
     internal readonly ReversedStack<CallStackFrame> ExceptionTrace = new();
     internal readonly ReversedStack<CallStackFrame> ExceptionTrace = new();
+    
+    // internal bool CancelRequested;
+    // internal CancellationToken CancellationToken;
 
 
     public bool IsRunning => CallStackFrameCount != 0;
     public bool IsRunning => CallStackFrameCount != 0;
     internal LuaFunction? Hook { get; set; }
     internal LuaFunction? Hook { get; set; }
@@ -131,7 +134,7 @@ public abstract class LuaThread
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     internal LuaThreadAccess PushCallStackFrame(in CallStackFrame frame)
     internal LuaThreadAccess PushCallStackFrame(in CallStackFrame frame)
     {
     {
-        CurrentException?.Build();
+        CurrentException?.BuildOrGet();
         CurrentException = null;
         CurrentException = null;
         ref var callStack = ref CoreData!.CallStack;
         ref var callStack = ref CoreData!.CallStack;
         callStack.Push(frame);
         callStack.Push(frame);

+ 15 - 0
src/Lua/LuaThreadExtensions.cs

@@ -13,4 +13,19 @@ public static class LuaThreadExtensions
     {
     {
         return new(LuaCoroutine.Create(thread, function, isProtectedMode));
         return new(LuaCoroutine.Create(thread, function, isProtectedMode));
     }
     }
+
+    internal static void ThrowIfCancellationRequested(this LuaThread thread, CancellationToken cancellationToken)
+    {
+        if (cancellationToken.IsCancellationRequested)
+        {
+            Throw(thread, cancellationToken);
+        }
+
+        return;
+
+        static void Throw(LuaThread thread, CancellationToken cancellationToken)
+        {
+            throw new LuaCancelledException(thread, cancellationToken);
+        }
+    }
 }
 }

+ 1 - 1
src/Lua/Runtime/LuaStack.cs

@@ -32,7 +32,7 @@ public sealed class LuaStack(int initialSize = 256)
 
 
             if (1000000 < size)
             if (1000000 < size)
             {
             {
-                throw new LuaException("Lua Stack overflow");
+                throw new ("Lua Stack overflow");
             }
             }
 
 
             Array.Resize(ref array, size);
             Array.Resize(ref array, size);

+ 3 - 2
src/Lua/Runtime/LuaThreadAccess.cs

@@ -53,6 +53,7 @@ public readonly struct LuaThreadAccess
             throw new ArgumentNullException(nameof(function));
             throw new ArgumentNullException(nameof(function));
         }
         }
 
 
+        Thread.ThrowIfCancellationRequested(cancellationToken);
         var thread = Thread;
         var thread = Thread;
         var varArgumentCount = function.GetVariableArgumentCount(argumentCount);
         var varArgumentCount = function.GetVariableArgumentCount(argumentCount);
         if (varArgumentCount != 0)
         if (varArgumentCount != 0)
@@ -78,7 +79,7 @@ public readonly struct LuaThreadAccess
 
 
         var access = thread.PushCallStackFrame(frame);
         var access = thread.PushCallStackFrame(frame);
         LuaFunctionExecutionContext context = new() { Access = access, ArgumentCount = argumentCount, ReturnFrameBase = returnBase, };
         LuaFunctionExecutionContext context = new() { Access = access, ArgumentCount = argumentCount, ReturnFrameBase = returnBase, };
-
+        var callStackTop = thread.CallStackFrameCount;
         try
         try
         {
         {
             if (this.Thread.CallOrReturnHookMask.Value != 0 && !this.Thread.IsInHook)
             if (this.Thread.CallOrReturnHookMask.Value != 0 && !this.Thread.IsInHook)
@@ -90,7 +91,7 @@ public readonly struct LuaThreadAccess
         }
         }
         finally
         finally
         {
         {
-            this.Thread.PopCallStackFrame();
+            this.Thread.PopCallStackFrameUntil(callStackTop-1);
         }
         }
     }
     }
 
 

+ 3 - 2
src/Lua/Runtime/LuaVirtualMachine.Debug.cs

@@ -47,7 +47,7 @@ public static partial class LuaVirtualMachine
                 countHookIsDone = true;
                 countHookIsDone = true;
             }
             }
 
 
-
+            context.ThrowIfCancellationRequested();
             if (context.Thread.IsLineHookEnabled)
             if (context.Thread.IsLineHookEnabled)
             {
             {
                 var sourcePositions = prototype.LineInfo;
                 var sourcePositions = prototype.LineInfo;
@@ -126,6 +126,7 @@ public static partial class LuaVirtualMachine
                 context.Thread.PopCallStackFrameWithStackPop();
                 context.Thread.PopCallStackFrameWithStackPop();
             }
             }
         }
         }
+        context.Thread.ThrowIfCancellationRequested(cancellationToken);
 
 
         {
         {
             var frame = context.Thread.GetCurrentFrame();
             var frame = context.Thread.GetCurrentFrame();
@@ -135,7 +136,7 @@ public static partial class LuaVirtualMachine
             {
             {
                 return r;
                 return r;
             }
             }
-
+            context.Thread.ThrowIfCancellationRequested(cancellationToken);
             var top = stack.Count;
             var top = stack.Count;
             stack.Push("return");
             stack.Push("return");
             stack.Push(LuaValue.Nil);
             stack.Push(LuaValue.Nil);

+ 61 - 14
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -1,6 +1,7 @@
 using System.Diagnostics.CodeAnalysis;
 using System.Diagnostics.CodeAnalysis;
 using System.Runtime.CompilerServices;
 using System.Runtime.CompilerServices;
 using Lua.Internal;
 using Lua.Internal;
+// ReSharper disable MethodHasAsyncOverload
 
 
 // ReSharper disable InconsistentNaming
 // ReSharper disable InconsistentNaming
 
 
@@ -255,6 +256,8 @@ public static partial class LuaVirtualMachine
                     {
                     {
                         break;
                         break;
                     }
                     }
+
+                    ThrowIfCancellationRequested();
                 }
                 }
 
 
                 return Thread.Stack.Count - returnFrameBase;
                 return Thread.Stack.Count - returnFrameBase;
@@ -264,6 +267,17 @@ public static partial class LuaVirtualMachine
                 pool.TryPush(this);
                 pool.TryPush(this);
             }
             }
         }
         }
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public void ThrowIfCancellationRequested()
+        {
+            if (!CancellationToken.IsCancellationRequested) return;
+            Throw();
+                
+            void Throw()
+            { 
+                GetThreadWithCurrentPc(this).ThrowIfCancellationRequested(CancellationToken);
+            }
+        }
     }
     }
 
 
     enum PostOperationType
     enum PostOperationType
@@ -567,6 +581,7 @@ public static partial class LuaVirtualMachine
                             context.Thread.State.CloseUpValues(context.Thread, frameBase + iA - 1);
                             context.Thread.State.CloseUpValues(context.Thread, frameBase + iA - 1);
                         }
                         }
 
 
+                        context.ThrowIfCancellationRequested();
                         continue;
                         continue;
                     case OpCode.Eq:
                     case OpCode.Eq:
                         Markers.Eq();
                         Markers.Eq();
@@ -695,6 +710,7 @@ public static partial class LuaVirtualMachine
                             indexRef = index;
                             indexRef = index;
                             Unsafe.Add(ref indexRef, 3) = index;
                             Unsafe.Add(ref indexRef, 3) = index;
                             stack.NotifyTop(iA + frameBase + 4);
                             stack.NotifyTop(iA + frameBase + 4);
+                            context.ThrowIfCancellationRequested();
                             continue;
                             continue;
                         }
                         }
 
 
@@ -799,7 +815,7 @@ public static partial class LuaVirtualMachine
         catch (Exception e)
         catch (Exception e)
         {
         {
             context.State.CloseUpValues(context.Thread, context.FrameBase);
             context.State.CloseUpValues(context.Thread, context.FrameBase);
-            if (e is not LuaRuntimeException)
+            if (e is not (LuaRuntimeException or LuaCancelledException))
             {
             {
                 var newException = new LuaRuntimeException(context.Thread, e);
                 var newException = new LuaRuntimeException(context.Thread, e);
                 context.PopOnTopCallStackFrames();
                 context.PopOnTopCallStackFrames();
@@ -1003,6 +1019,7 @@ public static partial class LuaVirtualMachine
     static async ValueTask ExecuteBinaryOperationMetaMethod(int target, LuaValue vb, LuaValue vc,
     static async ValueTask ExecuteBinaryOperationMetaMethod(int target, LuaValue vb, LuaValue vc,
         VirtualMachineExecutionContext context, OpCode opCode)
         VirtualMachineExecutionContext context, OpCode opCode)
     {
     {
+        context.ThrowIfCancellationRequested();
         var (name, description) = opCode.GetNameAndDescription();
         var (name, description) = opCode.GetNameAndDescription();
         if (vb.TryGetMetamethod(context.State, name, out var metamethod) ||
         if (vb.TryGetMetamethod(context.State, name, out var metamethod) ||
             vc.TryGetMetamethod(context.State, name, out metamethod))
             vc.TryGetMetamethod(context.State, name, out metamethod))
@@ -1039,12 +1056,14 @@ public static partial class LuaVirtualMachine
                     await ExecuteCallHook(functionContext, context.CancellationToken);
                     await ExecuteCallHook(functionContext, context.CancellationToken);
                     stack.PopUntil(target + 1);
                     stack.PopUntil(target + 1);
                     context.PostOperation = PostOperationType.DontPop;
                     context.PostOperation = PostOperationType.DontPop;
+                    context.ThrowIfCancellationRequested();
                     return;
                     return;
                 }
                 }
 
 
                 await func.Func(functionContext, context.CancellationToken);
                 await func.Func(functionContext, context.CancellationToken);
                 stack.PopUntil(target + 1);
                 stack.PopUntil(target + 1);
                 context.PostOperation = PostOperationType.DontPop;
                 context.PostOperation = PostOperationType.DontPop;
+                context.ThrowIfCancellationRequested();
                 return;
                 return;
             }
             }
             finally
             finally
@@ -1060,6 +1079,7 @@ public static partial class LuaVirtualMachine
 
 
     static bool Call(VirtualMachineExecutionContext context, out bool doRestart)
     static bool Call(VirtualMachineExecutionContext context, out bool doRestart)
     {
     {
+        context.ThrowIfCancellationRequested();
         var instruction = context.Instruction;
         var instruction = context.Instruction;
         var RA = instruction.A + context.FrameBase;
         var RA = instruction.A + context.FrameBase;
         var newBase = RA + 1;
         var newBase = RA + 1;
@@ -1119,6 +1139,7 @@ public static partial class LuaVirtualMachine
             var awaiter = task.GetAwaiter();
             var awaiter = task.GetAwaiter();
 
 
             awaiter.GetResult();
             awaiter.GetResult();
+            context.Thread.ThrowIfCancellationRequested(context.CancellationToken);
             var instruction = context.Instruction;
             var instruction = context.Instruction;
             var ic = instruction.C;
             var ic = instruction.C;
 
 
@@ -1137,8 +1158,9 @@ public static partial class LuaVirtualMachine
         }
         }
     }
     }
 
 
-    internal static async ValueTask<int> Call(LuaThread thread, int funcIndex, int returnBase, CancellationToken ct)
+    internal static async ValueTask<int> Call(LuaThread thread, int funcIndex, int returnBase, CancellationToken cancellationToken)
     {
     {
+        thread.ThrowIfCancellationRequested(cancellationToken);
         var stack = thread.Stack;
         var stack = thread.Stack;
         var newBase = funcIndex + 1;
         var newBase = funcIndex + 1;
         var va = stack.Get(funcIndex);
         var va = stack.Get(funcIndex);
@@ -1165,12 +1187,24 @@ public static partial class LuaVirtualMachine
             var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = returnBase };
             var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = returnBase };
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
             {
             {
-                await ExecuteCallHook(functionContext, ct);
+                await ExecuteCallHook(functionContext, cancellationToken);
+            }
+            else
+            {
+                await func.Func(functionContext, cancellationToken);
             }
             }
 
 
-            await func.Func(functionContext, ct);
+            thread.ThrowIfCancellationRequested(cancellationToken);
             return thread.Stack.Count - funcIndex;
             return thread.Stack.Count - funcIndex;
         }
         }
+        catch(OperationCanceledException  operationCanceledException)
+        {
+            if(operationCanceledException is not LuaCancelledException)
+            {
+                throw new LuaCancelledException(thread, cancellationToken, operationCanceledException);
+            }
+            throw;
+        }
         finally
         finally
         {
         {
             thread.PopCallStackFrame();
             thread.PopCallStackFrame();
@@ -1195,6 +1229,7 @@ public static partial class LuaVirtualMachine
 
 
     static bool TailCall(VirtualMachineExecutionContext context, out bool doRestart)
     static bool TailCall(VirtualMachineExecutionContext context, out bool doRestart)
     {
     {
+        context.ThrowIfCancellationRequested();
         var instruction = context.Instruction;
         var instruction = context.Instruction;
         var stack = context.Stack;
         var stack = context.Stack;
         var RA = instruction.A + context.FrameBase;
         var RA = instruction.A + context.FrameBase;
@@ -1259,6 +1294,7 @@ public static partial class LuaVirtualMachine
 
 
 
 
         task.GetAwaiter().GetResult();
         task.GetAwaiter().GetResult();
+        context.ThrowIfCancellationRequested();
         if (!context.PopFromBuffer(context.CurrentReturnFrameBase, context.Stack.Count - context.CurrentReturnFrameBase))
         if (!context.PopFromBuffer(context.CurrentReturnFrameBase, context.Stack.Count - context.CurrentReturnFrameBase))
         {
         {
             return true;
             return true;
@@ -1270,6 +1306,7 @@ public static partial class LuaVirtualMachine
 
 
     static bool TForCall(VirtualMachineExecutionContext context, out bool doRestart)
     static bool TForCall(VirtualMachineExecutionContext context, out bool doRestart)
     {
     {
+        context.ThrowIfCancellationRequested();
         doRestart = false;
         doRestart = false;
         var instruction = context.Instruction;
         var instruction = context.Instruction;
         var stack = context.Stack;
         var stack = context.Stack;
@@ -1341,8 +1378,8 @@ public static partial class LuaVirtualMachine
             return false;
             return false;
         }
         }
 
 
-        var awaiter = task.GetAwaiter();
-        awaiter.GetResult();
+        task.GetAwaiter().GetResult();
+        context.ThrowIfCancellationRequested();
         context.Thread.PopCallStackFrame();
         context.Thread.PopCallStackFrame();
         TForCallPostOperation(context);
         TForCallPostOperation(context);
         return true;
         return true;
@@ -1942,8 +1979,10 @@ public static partial class LuaVirtualMachine
     }
     }
 
 
     [MethodImpl(MethodImplOptions.NoInlining)]
     [MethodImpl(MethodImplOptions.NoInlining)]
-    internal static async ValueTask<LuaValue> ExecuteUnaryOperationMetaMethod(LuaThread thread, LuaValue vb, OpCode opCode, CancellationToken ct)
+    internal static async ValueTask<LuaValue> ExecuteUnaryOperationMetaMethod(LuaThread thread, LuaValue vb, OpCode opCode, CancellationToken cancellationToken)
     {
     {
+        thread.ThrowIfCancellationRequested(cancellationToken);
+
         var (name, description) = opCode.GetNameAndDescription();
         var (name, description) = opCode.GetNameAndDescription();
 
 
         if (vb.TryGetMetamethod(thread.State, name, out var metamethod))
         if (vb.TryGetMetamethod(thread.State, name, out var metamethod))
@@ -1976,11 +2015,14 @@ public static partial class LuaVirtualMachine
                 var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = newBase };
                 var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = newBase };
                 if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
                 if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
                 {
                 {
-                    await ExecuteCallHook(functionContext, ct);
+                    await ExecuteCallHook(functionContext, cancellationToken);
+                }
+                else
+                {
+                    await func.Func(functionContext, cancellationToken);
                 }
                 }
 
 
-
-                await func.Func(functionContext, ct);
+                thread.ThrowIfCancellationRequested(cancellationToken);
                 var results = stack.GetBuffer()[newFrame.ReturnBase..];
                 var results = stack.GetBuffer()[newFrame.ReturnBase..];
                 var result = results.Length == 0 ? default : results[0];
                 var result = results.Length == 0 ? default : results[0];
                 results.Clear();
                 results.Clear();
@@ -2097,8 +2139,10 @@ public static partial class LuaVirtualMachine
     }
     }
 
 
     [MethodImpl(MethodImplOptions.NoInlining)]
     [MethodImpl(MethodImplOptions.NoInlining)]
-    internal static async ValueTask<bool> ExecuteCompareOperationMetaMethod(LuaThread thread, LuaValue vb, LuaValue vc, OpCode opCode, CancellationToken ct)
+    internal static async ValueTask<bool> ExecuteCompareOperationMetaMethod(LuaThread thread, LuaValue vb, LuaValue vc, OpCode opCode, CancellationToken cancellationToken)
     {
     {
+        thread.ThrowIfCancellationRequested(cancellationToken);
+
         var (name, description) = opCode.GetNameAndDescription();
         var (name, description) = opCode.GetNameAndDescription();
         bool reverseLe = false;
         bool reverseLe = false;
     ReCheck:
     ReCheck:
@@ -2133,11 +2177,14 @@ public static partial class LuaVirtualMachine
                 var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = newBase };
                 var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = newBase };
                 if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
                 if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook)
                 {
                 {
-                    await ExecuteCallHook(functionContext, ct);
+                    await ExecuteCallHook(functionContext, cancellationToken);
+                }
+                else
+                {
+                    await func.Func(functionContext, cancellationToken);
                 }
                 }
 
 
-
-                await func.Func(functionContext, ct);
+                thread.ThrowIfCancellationRequested(cancellationToken);
                 var results = stack.GetBuffer()[newFrame.ReturnBase..];
                 var results = stack.GetBuffer()[newFrame.ReturnBase..];
                 var result = results.Length == 0 ? default : results[0];
                 var result = results.Length == 0 ? default : results[0];
                 results.Clear();
                 results.Clear();

+ 12 - 7
src/Lua/Standard/BasicLibrary.cs

@@ -1,6 +1,7 @@
 using System.Globalization;
 using System.Globalization;
 using Lua.Internal;
 using Lua.Internal;
 using Lua.Runtime;
 using Lua.Runtime;
+
 // ReSharper disable MethodHasAsyncOverloadWithCancellation
 // ReSharper disable MethodHasAsyncOverloadWithCancellation
 
 
 namespace Lua.Standard;
 namespace Lua.Standard;
@@ -259,14 +260,17 @@ public sealed class BasicLibrary
         catch (Exception ex)
         catch (Exception ex)
         {
         {
             context.Thread.PopCallStackFrameUntil(frameCount);
             context.Thread.PopCallStackFrameUntil(frameCount);
-            if (ex is LuaRuntimeException luaEx)
-            {
-                luaEx.Forget();
-                return context.Return(false, luaEx.ErrorObject);
-            }
-            else
+            switch (ex)
             {
             {
-                return context.Return(false, ex.Message);
+                case LuaCancelledException:
+                    throw;
+                case OperationCanceledException:
+                    throw new LuaCancelledException(context.Thread,cancellationToken, ex);
+                case LuaRuntimeException luaEx:
+                    luaEx.Forget();
+                    return context.Return(false, luaEx.ErrorObject);
+                default:
+                    return context.Return(false, ex.Message);
             }
             }
         }
         }
     }
     }
@@ -565,6 +569,7 @@ public sealed class BasicLibrary
         {
         {
             var thread = context.Thread;
             var thread = context.Thread;
             thread.PopCallStackFrameUntil(frameCount);
             thread.PopCallStackFrameUntil(frameCount);
+            cancellationToken.ThrowIfCancellationRequested();
 
 
             var access = thread.CurrentAccess;
             var access = thread.CurrentAccess;
             if (ex is LuaRuntimeException luaEx)
             if (ex is LuaRuntimeException luaEx)

+ 180 - 0
tests/Lua.Tests/CancellationTest.cs

@@ -0,0 +1,180 @@
+using Lua.Standard;
+
+namespace Lua.Tests;
+
+public class CancellationTest
+{
+    LuaState state = default!;
+
+    [SetUp]
+    public void SetUp()
+    {
+        state = LuaState.Create();
+        state.OpenStandardLibraries();
+
+        state.Environment["assert"] = new LuaFunction("assert_with_wait",
+            async (context, ct) =>
+            {
+                await Task.Delay(1, ct);
+                var arg0 = context.GetArgument(0);
+
+                if (!arg0.ToBoolean())
+                {
+                    var message = "assertion failed!";
+                    if (context.HasArgument(1))
+                    {
+                        message = context.GetArgument<string>(1);
+                    }
+
+                    throw new LuaAssertionException(context.Thread, message);
+                }
+
+                return (context.Return(context.Arguments));
+            });
+        state.Environment["sleep"] = new LuaFunction("sleep",
+            (context, _) =>
+            {
+                Thread.Sleep(context.GetArgument<int>(0));
+
+                return new(context.Return());
+            });
+        state.Environment["wait"] = new LuaFunction("wait",
+            async (context, ct) =>
+            {
+                await Task.Delay(context.GetArgument<int>(0), ct);
+                return context.Return();
+            });
+    }
+
+    [Test]
+    public async Task PCall_WaitTest()
+    {
+        var source = """
+                     local function f(millisec)
+                         wait(millisec)
+                     end                     
+                     pcall(f, 500)
+                     """;
+        var cancellationTokenSource = new CancellationTokenSource();
+        cancellationTokenSource.CancelAfter(200);
+
+        try
+        {
+            await state.DoStringAsync(source, "@test.lua", cancellationTokenSource.Token);
+            Assert.Fail("Expected TaskCanceledException was not thrown.");
+        }
+        catch (Exception e)
+        {
+            Assert.That(e, Is.TypeOf<LuaCancelledException>());
+            var luaCancelledException = (LuaCancelledException)e;
+            Assert.That(luaCancelledException.InnerException, Is.TypeOf<TaskCanceledException>());
+            var luaStackTrace = luaCancelledException.LuaTraceback!.ToString();
+            Console.WriteLine(luaStackTrace);
+            Assert.That(luaStackTrace, Contains.Substring("'wait'"));
+            Assert.That(luaStackTrace, Contains.Substring("'pcall'"));
+        }
+    }
+
+    [Test]
+    public async Task PCall_SleepTest()
+    {
+        var source = """
+                     local function f(millisec)
+                         sleep(millisec)
+                     end                     
+                     pcall(f, 500)
+                     """;
+        var cancellationTokenSource = new CancellationTokenSource();
+        cancellationTokenSource.CancelAfter(250);
+
+        try
+        {
+            await state.DoStringAsync(source, "@test.lua", cancellationTokenSource.Token);
+            Assert.Fail("Expected TaskCanceledException was not thrown.");
+        }
+        catch (Exception e)
+        {
+            Assert.That(e, Is.TypeOf<LuaCancelledException>());
+            var luaCancelledException = (LuaCancelledException)e;
+            Assert.That(luaCancelledException.InnerException, Is.Null);
+            var luaStackTrace = luaCancelledException.LuaTraceback!.ToString();
+            Console.WriteLine(luaStackTrace);
+            Assert.That(luaStackTrace, Contains.Substring("'sleep'"));
+            Assert.That(luaStackTrace, Contains.Substring("'pcall'"));
+        }
+    }
+
+    [Test]
+    public async Task ForLoopTest()
+    {
+        var source = """
+                     local ret = 0
+                     for i = 1, 1000000000 do
+                         ret = ret + i
+                     end
+                     return ret
+                     """;
+        var cancellationTokenSource = new CancellationTokenSource();
+        cancellationTokenSource.CancelAfter(100);
+        cancellationTokenSource.Token.Register(() =>
+        {
+            Console.WriteLine("Cancellation requested");
+        });
+        try
+        {
+            var r = await state.DoStringAsync(source, "@test.lua", cancellationTokenSource.Token);
+            Console.WriteLine(r[0]);
+            Assert.Fail("Expected TaskCanceledException was not thrown.");
+        }
+        catch (Exception e)
+        {
+            Assert.That(e, Is.TypeOf<LuaCancelledException>());
+            Console.WriteLine(e.StackTrace);
+            var luaCancelledException = (LuaCancelledException)e;
+            Assert.That(luaCancelledException.InnerException, Is.Null);
+            var traceback = luaCancelledException.LuaTraceback;
+            if (traceback != null)
+            {
+                var luaStackTrace = traceback.ToString();
+                Console.WriteLine(luaStackTrace);
+            }
+        }
+    }
+    
+    [Test]
+    public async Task GoToLoopTest()
+    {
+        var source = """
+                     local ret = 0
+                     ::loop::
+                     ret = ret + 1
+                     goto loop
+                     return ret
+                     """;
+        var cancellationTokenSource = new CancellationTokenSource();
+        cancellationTokenSource.CancelAfter(100);
+        cancellationTokenSource.Token.Register(() =>
+        {
+            Console.WriteLine("Cancellation requested");
+        });
+        try
+        {
+            var r = await state.DoStringAsync(source, "@test.lua", cancellationTokenSource.Token);
+            Console.WriteLine(r[0]);
+            Assert.Fail("Expected TaskCanceledException was not thrown.");
+        }
+        catch (Exception e)
+        {
+            Assert.That(e, Is.TypeOf<LuaCancelledException>());
+            Console.WriteLine(e.StackTrace);
+            var luaCancelledException = (LuaCancelledException)e;
+            Assert.That(luaCancelledException.InnerException, Is.Null);
+            var traceback = luaCancelledException.LuaTraceback;
+            if (traceback != null)
+            {
+                var luaStackTrace = traceback.ToString();
+                Console.WriteLine(luaStackTrace);
+            }
+        }
+    }
+}