Browse Source

refactor: replace File.ReadAllBytesAsync with sync to prevent unintended Threading

Akeit0 7 months ago
parent
commit
40ec0535d5
2 changed files with 20 additions and 19 deletions
  1. 4 2
      src/Lua/Runtime/LuaThreadAccessExtensions.cs
  2. 16 17
      src/Lua/Standard/BasicLibrary.cs

+ 4 - 2
src/Lua/Runtime/LuaThreadAccessExtensions.cs

@@ -1,5 +1,7 @@
 using System.Runtime.CompilerServices;
 
+// ReSharper disable MethodHasAsyncOverloadWithCancellation
+
 namespace Lua.Runtime;
 
 public static class LuaThreadAccessAccessExtensions
@@ -26,7 +28,7 @@ public static class LuaThreadAccessAccessExtensions
     public static async ValueTask<int> DoFileAsync(this LuaThreadAccess access, string path, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
     {
         access.ThrowIfInvalid();
-        var bytes = await File.ReadAllBytesAsync(path, cancellationToken);
+        var bytes = File.ReadAllBytes(path);
         var fileName = "@" + path;
         var closure = access.State.Load(bytes, fileName);
         var count = await access.RunAsync(closure, 0, cancellationToken);
@@ -37,7 +39,7 @@ public static class LuaThreadAccessAccessExtensions
 
     public static async ValueTask<LuaValue[]> DoFileAsync(this LuaThreadAccess access, string path, CancellationToken cancellationToken = default)
     {
-        var bytes = await File.ReadAllBytesAsync(path, cancellationToken);
+        var bytes = File.ReadAllBytes(path);
         var fileName = "@" + path;
         var closure = access.State.Load(bytes, fileName);
         var count = await access.RunAsync(closure, 0, cancellationToken);

+ 16 - 17
src/Lua/Standard/BasicLibrary.cs

@@ -1,6 +1,7 @@
 using System.Globalization;
 using Lua.Internal;
 using Lua.Runtime;
+// ReSharper disable MethodHasAsyncOverloadWithCancellation
 
 namespace Lua.Standard;
 
@@ -88,11 +89,10 @@ public sealed class BasicLibrary
         var arg0 = context.GetArgument<string>(0);
         context.Thread.Stack.PopUntil(context.ReturnFrameBase);
 
-        // do not use LuaState.DoFileAsync as it uses the newExecutionContext
-        var bytes =   File.ReadAllBytes(arg0);
+        var bytes = File.ReadAllBytes(arg0);
         var fileName = "@" + arg0;
         var closure = context.State.Load(bytes, fileName);
-        return await context.Access.RunAsync(closure,cancellationToken);
+        return await context.Access.RunAsync(closure, cancellationToken);
     }
 
     public ValueTask<int> Error(LuaFunctionExecutionContext context, CancellationToken cancellationToken)
@@ -139,15 +139,15 @@ public sealed class BasicLibrary
             stack.Push(metamethod);
             stack.Push(arg0);
 
-            await LuaVirtualMachine.Call(context.Access.Thread,top,context.ReturnFrameBase,cancellationToken);
-            stack.SetTop(context.ReturnFrameBase+3);
+            await LuaVirtualMachine.Call(context.Access.Thread, top, context.ReturnFrameBase, cancellationToken);
+            stack.SetTop(context.ReturnFrameBase + 3);
             return 3;
         }
 
         return context.Return(IPairsIterator, arg0, 0);
     }
 
-    public async ValueTask<int> LoadFile(LuaFunctionExecutionContext context, CancellationToken cancellationToken)
+    public ValueTask<int> LoadFile(LuaFunctionExecutionContext context, CancellationToken cancellationToken)
     {
         var arg0 = context.GetArgument<string>(0);
         var mode = context.HasArgument(1)
@@ -160,13 +160,13 @@ public sealed class BasicLibrary
         // do not use LuaState.DoFileAsync as it uses the newExecutionContext
         try
         {
-            var bytes = await File.ReadAllBytesAsync(arg0, cancellationToken);
+            var bytes = File.ReadAllBytes(arg0);
             var fileName = "@" + arg0;
-            return context.Return(context.State.Load(bytes, fileName, mode, arg2));
+            return new(context.Return(context.State.Load(bytes, fileName, mode, arg2)));
         }
         catch (Exception ex)
         {
-            return context.Return(LuaValue.Nil, ex.Message);
+            return new(context.Return(LuaValue.Nil, ex.Message));
         }
     }
 
@@ -238,10 +238,9 @@ public sealed class BasicLibrary
             stack.Push(metamethod);
             stack.Push(arg0);
 
-            await LuaVirtualMachine.Call(context.Access.Thread,top,context.ReturnFrameBase,cancellationToken);
-            stack.SetTop(context.ReturnFrameBase+3);
+            await LuaVirtualMachine.Call(context.Access.Thread, top, context.ReturnFrameBase, cancellationToken);
+            stack.SetTop(context.ReturnFrameBase + 3);
             return 3;
-            
         }
 
         return (context.Return(PairsIterator, arg0, LuaValue.Nil));
@@ -252,7 +251,7 @@ public sealed class BasicLibrary
         var frameCount = context.Thread.CallStackFrameCount;
         try
         {
-            var count =  await LuaVirtualMachine.Call(context.Access.Thread,context.FrameBase,context.ReturnFrameBase+1,cancellationToken);
+            var count = await LuaVirtualMachine.Call(context.Access.Thread, context.FrameBase, context.ReturnFrameBase + 1, cancellationToken);
 
             context.Thread.Stack.Get(context.ReturnFrameBase) = true;
             return count + 1;
@@ -556,8 +555,8 @@ public sealed class BasicLibrary
         try
         {
             var stack = context.Thread.Stack;
-            stack.Get(context.FrameBase+1) = arg0;
-            var count =  await LuaVirtualMachine.Call(context.Access.Thread,context.FrameBase + 1,context.ReturnFrameBase+1,cancellationToken);
+            stack.Get(context.FrameBase + 1) = arg0;
+            var count = await LuaVirtualMachine.Call(context.Access.Thread, context.FrameBase + 1, context.ReturnFrameBase + 1, cancellationToken);
 
             context.Thread.Stack.Get(context.ReturnFrameBase) = true;
             return count + 1;
@@ -566,7 +565,7 @@ public sealed class BasicLibrary
         {
             var thread = context.Thread;
             thread.PopCallStackFrameUntil(frameCount);
-            
+
             var access = thread.CurrentAccess;
             if (ex is LuaRuntimeException luaEx)
             {
@@ -580,7 +579,7 @@ public sealed class BasicLibrary
 
 
             // invoke error handler
-            var count = await access.RunAsync(arg1, 1,context.ReturnFrameBase+1, cancellationToken);
+            var count = await access.RunAsync(arg1, 1, context.ReturnFrameBase + 1, cancellationToken);
             context.Thread.Stack.Get(context.ReturnFrameBase) = false;
             return count + 1;
         }