Browse Source

Fix: upvalues ​​not working in coroutines

AnnulusGames 1 year ago
parent
commit
a5118d0824

+ 43 - 0
src/Lua/Internal/FastList.cs

@@ -0,0 +1,43 @@
+using System.Runtime.CompilerServices;
+
+namespace Lua.Internal;
+
+public class FastList<T>
+{
+    FastListCore<T> core;
+
+    public int Length
+    {
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        get
+        {
+            return core.Length;
+        }
+    }
+
+    public ref T this[int index]
+    {
+        get
+        {
+            return ref core[index];
+        }
+    }
+
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    public void Add(T value)
+    {
+        core.Add(value);
+    }
+
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    public void RemoveAtSwapback(int index)
+    {
+        core.RemoveAtSwapback(index);
+    }
+
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    public Span<T> AsSpan()
+    {
+        return core.AsSpan();
+    }
+}

+ 2 - 4
src/Lua/Internal/FastListCore.cs

@@ -85,12 +85,10 @@ public struct FastListCore<T>
         AsSpan().CopyTo(destination.AsSpan());
     }
 
-    public readonly T this[int index]
+    public ref T this[int index]
     {
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        get => array![index];
-        [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        set => array![index] = value;
+        get => ref array![index];
     }
 
     public readonly int Length

+ 6 - 1
src/Lua/LuaFunction.cs

@@ -4,9 +4,14 @@ namespace Lua;
 
 public abstract partial class LuaFunction
 {
-    internal LuaThread? thread;
+    LuaThread? thread;
     public LuaThread? Thread => thread;
 
+    internal void SetCurrentThread(LuaThread thread)
+    {
+        this.thread = thread;
+    }
+
     public virtual string Name => GetType().Name;
 
     public async ValueTask<int> InvokeAsync(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)

+ 10 - 6
src/Lua/LuaState.cs

@@ -10,7 +10,7 @@ public sealed class LuaState
 
     LuaStack stack = new();
     FastStackCore<CallStackFrame> callStack;
-    FastListCore<UpValue> openUpValues;
+    FastList<UpValue> openUpValues;
 
     LuaTable environment;
     internal UpValue EnvUpValue { get; }
@@ -34,13 +34,15 @@ public sealed class LuaState
     LuaState()
     {
         environment = new();
-        EnvUpValue = UpValue.Closed(environment);
+        EnvUpValue = UpValue.Closed(this, environment);
+        openUpValues = new();
     }
 
     LuaState(LuaState parent)
     {
         environment = parent.Environment;
         EnvUpValue = parent.EnvUpValue;
+        openUpValues = parent.openUpValues;
     }
 
     public async ValueTask<int> RunAsync(Chunk chunk, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
@@ -99,7 +101,7 @@ public sealed class LuaState
     public bool TryGetCurrentThread([NotNullWhen(true)] out LuaThread? result)
     {
         var span = GetCallStackSpan();
-        
+
         for (int i = 0; i < span.Length; i++)
         {
             result = span[i].Function.Thread;
@@ -127,13 +129,13 @@ public sealed class LuaState
     {
         foreach (var upValue in openUpValues.AsSpan())
         {
-            if (upValue.RegisterIndex == registerIndex)
+            if (upValue.RegisterIndex == registerIndex && upValue.State == this)
             {
                 return upValue;
             }
         }
 
-        var newUpValue = UpValue.Open(registerIndex);
+        var newUpValue = UpValue.Open(this, registerIndex);
         openUpValues.Add(newUpValue);
         return newUpValue;
     }
@@ -143,9 +145,11 @@ public sealed class LuaState
         for (int i = 0; i < openUpValues.Length; i++)
         {
             var upValue = openUpValues[i];
+            if (upValue.State != this) continue;
+
             if (upValue.RegisterIndex >= frameBase)
             {
-                upValue.Close(this);
+                upValue.Close();
                 openUpValues.RemoveAtSwapback(i);
                 i--;
             }

+ 3 - 4
src/Lua/LuaThread.cs

@@ -3,7 +3,6 @@ namespace Lua;
 public sealed class LuaThread
 {
     LuaThreadStatus status;
-    bool isProtectedMode;
     LuaState threadState;
     Task<int>? functionTask;
 
@@ -11,15 +10,15 @@ public sealed class LuaThread
     TaskCompletionSource<object?> yield = new();
 
     public LuaThreadStatus Status => status;
-    public bool IsProtectedMode => isProtectedMode;
+    public bool IsProtectedMode { get; }
     public LuaFunction Function { get; }
 
     internal LuaThread(LuaState state, LuaFunction function, bool isProtectedMode)
     {
-        this.isProtectedMode = isProtectedMode;
+        IsProtectedMode = isProtectedMode;
         threadState = state.CreateCoroutineState();
         Function = function;
-        function.thread = this;
+        function.SetCurrentThread(this);
     }
 
     public async Task<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)

+ 4 - 4
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -47,7 +47,7 @@ public static partial class LuaVirtualMachine
                     {
                         stack.EnsureCapacity(RA + 1);
                         var upValue = closure.UpValues[instruction.B];
-                        stack.UnsafeGet(RA) = upValue.GetValue(state);
+                        stack.UnsafeGet(RA) = upValue.GetValue();
                         stack.NotifyTop(RA + 1);
                         break;
                     }
@@ -56,7 +56,7 @@ public static partial class LuaVirtualMachine
                         stack.EnsureCapacity(RA + 1);
                         var vc = RK(stack, chunk, instruction.C, frame.Base);
                         var upValue = closure.UpValues[instruction.B];
-                        var table = upValue.GetValue(state);
+                        var table = upValue.GetValue();
                         var value = await GetTableValue(state, chunk, pc, table, vc, cancellationToken);
                         stack.UnsafeGet(RA) = value;
                         stack.NotifyTop(RA + 1);
@@ -78,14 +78,14 @@ public static partial class LuaVirtualMachine
                         var vc = RK(stack, chunk, instruction.C, frame.Base);
 
                         var upValue = closure.UpValues[instruction.A];
-                        var table = upValue.GetValue(state);
+                        var table = upValue.GetValue();
                         await SetTableValue(state, chunk, pc, table, vb, vc, cancellationToken);
                         break;
                     }
                 case OpCode.SetUpVal:
                     {
                         var upValue = closure.UpValues[instruction.B];
-                        upValue.SetValue(state, stack.UnsafeGet(RA));
+                        upValue.SetValue(stack.UnsafeGet(RA));
                         break;
                     }
                 case OpCode.SetTable:

+ 13 - 11
src/Lua/Runtime/UpValue.cs

@@ -7,24 +7,26 @@ public sealed class UpValue
 {
     LuaValue value;
 
+    public LuaState State { get; }
     public bool IsClosed { get; private set; }
     public int RegisterIndex { get; private set; }
 
-    UpValue()
+    UpValue(LuaState state)
     {
+        State = state;
     }
 
-    public static UpValue Open(int registerIndex)
+    public static UpValue Open(LuaState state, int registerIndex)
     {
-        return new()
+        return new(state)
         {
             RegisterIndex = registerIndex
         };
     }
 
-    public static UpValue Closed(LuaValue value)
+    public static UpValue Closed(LuaState state, LuaValue value)
     {
-        return new()
+        return new(state)
         {
             IsClosed = true,
             value = value
@@ -32,7 +34,7 @@ public sealed class UpValue
     }
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public LuaValue GetValue(LuaState state)
+    public LuaValue GetValue()
     {
         if (IsClosed)
         {
@@ -40,12 +42,12 @@ public sealed class UpValue
         }
         else
         {
-            return state.Stack.UnsafeGet(RegisterIndex);
+            return State.Stack.UnsafeGet(RegisterIndex);
         }
     }
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public void SetValue(LuaState state, LuaValue value)
+    public void SetValue(LuaValue value)
     {
         if (IsClosed)
         {
@@ -53,16 +55,16 @@ public sealed class UpValue
         }
         else
         {
-            state.Stack.UnsafeGet(RegisterIndex) = value;
+            State.Stack.UnsafeGet(RegisterIndex) = value;
         }
     }
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public void Close(LuaState state)
+    public void Close()
     {
         if (!IsClosed)
         {
-            value = state.Stack.UnsafeGet(RegisterIndex);
+            value = State.Stack.UnsafeGet(RegisterIndex);
         }
 
         IsClosed = true;