Browse Source

Fix: closure upvalues

Akeit0 1 year ago
parent
commit
2f504aa558

+ 16 - 12
src/Lua/CodeAnalysis/Compilation/LuaCompiler.cs

@@ -255,23 +255,26 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
                         // For the last element, we need to take into account variable arguments and multiple return values.
                         // For the last element, we need to take into account variable arguments and multiple return values.
                         if (listItem == lastField)
                         if (listItem == lastField)
                         {
                         {
+                            bool isFixedItems = true;
                             switch (listItem.Expression)
                             switch (listItem.Expression)
                             {
                             {
                                 case CallFunctionExpressionNode call:
                                 case CallFunctionExpressionNode call:
                                     CompileCallFunctionExpression(call, context, false, -1);
                                     CompileCallFunctionExpression(call, context, false, -1);
+                                    isFixedItems = false;
                                     break;
                                     break;
                                 case CallTableMethodExpressionNode method:
                                 case CallTableMethodExpressionNode method:
                                     CompileTableMethod(method, context, false, -1);
                                     CompileTableMethod(method, context, false, -1);
                                     break;
                                     break;
                                 case VariableArgumentsExpressionNode varArg:
                                 case VariableArgumentsExpressionNode varArg:
                                     CompileVariableArgumentsExpression(varArg, context, -1);
                                     CompileVariableArgumentsExpression(varArg, context, -1);
+                                    isFixedItems = false;
                                     break;
                                     break;
                                 default:
                                 default:
                                     listItem.Expression.Accept(this, context);
                                     listItem.Expression.Accept(this, context);
                                     break;
                                     break;
                             }
                             }
 
 
-                            context.PushInstruction(Instruction.SetList(tableRegisterIndex, 0, arrayBlock), listItem.Position);
+                            context.PushInstruction(Instruction.SetList(tableRegisterIndex, (ushort)(isFixedItems ? context.StackTopPosition - tableRegisterIndex: 0), arrayBlock), listItem.Position);
                             currentArrayChunkSize = 0;
                             currentArrayChunkSize = 0;
                         }
                         }
                         else
                         else
@@ -736,7 +739,7 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
     {
     {
         using var endJumpIndexList = new PooledList<int>(8);
         using var endJumpIndexList = new PooledList<int>(8);
         var hasElse = node.ElseNodes.Length > 0;
         var hasElse = node.ElseNodes.Length > 0;
-
+        var stackPositionToClose = (byte)(context.StackPosition + 1);
         // if
         // if
         using (var scopeContext = context.CreateChildScope())
         using (var scopeContext = context.CreateChildScope())
         {
         {
@@ -750,15 +753,15 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
                 childNode.Accept(this, scopeContext);
                 childNode.Accept(this, scopeContext);
             }
             }
 
 
+            stackPositionToClose =scopeContext.HasCapturedLocalVariables ? stackPositionToClose: (byte)0;
             if (hasElse)
             if (hasElse)
             {
             {
                 endJumpIndexList.Add(scopeContext.Function.Instructions.Length);
                 endJumpIndexList.Add(scopeContext.Function.Instructions.Length);
-                var a = scopeContext.HasCapturedLocalVariables ? scopeContext.StackPosition : (byte)0;
-                scopeContext.PushInstruction(Instruction.Jmp(a, 0), node.Position, true);
+                scopeContext.PushInstruction(Instruction.Jmp(stackPositionToClose, 0), node.Position, true);
             }
             }
             else
             else
             {
             {
-                scopeContext.TryPushCloseUpValue(scopeContext.StackPosition, node.Position);
+                scopeContext.TryPushCloseUpValue(stackPositionToClose, node.Position);
             }
             }
 
 
             scopeContext.Function.Instructions[ifPosition].SBx = scopeContext.Function.Instructions.Length - 1 - ifPosition;
             scopeContext.Function.Instructions[ifPosition].SBx = scopeContext.Function.Instructions.Length - 1 - ifPosition;
@@ -779,16 +782,16 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
                 childNode.Accept(this, scopeContext);
                 childNode.Accept(this, scopeContext);
             }
             }
 
 
+            stackPositionToClose =scopeContext.HasCapturedLocalVariables ? stackPositionToClose: (byte)0;
             // skip if node doesn't have else statements
             // skip if node doesn't have else statements
             if (hasElse)
             if (hasElse)
             {
             {
                 endJumpIndexList.Add(scopeContext.Function.Instructions.Length);
                 endJumpIndexList.Add(scopeContext.Function.Instructions.Length);
-                var a = scopeContext.HasCapturedLocalVariables ? scopeContext.StackPosition : (byte)0;
-                scopeContext.PushInstruction(Instruction.Jmp(a, 0), node.Position);
+                scopeContext.PushInstruction(Instruction.Jmp(stackPositionToClose, 0), node.Position);
             }
             }
             else
             else
             {
             {
-                scopeContext.TryPushCloseUpValue(scopeContext.StackPosition, node.Position);
+                scopeContext.TryPushCloseUpValue(stackPositionToClose, node.Position);
             }
             }
 
 
             scopeContext.Function.Instructions[elseifPosition].SBx = scopeContext.Function.Instructions.Length - 1 - elseifPosition;
             scopeContext.Function.Instructions[elseifPosition].SBx = scopeContext.Function.Instructions.Length - 1 - elseifPosition;
@@ -810,7 +813,7 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
         {
         {
             context.Function.Instructions[index].SBx = context.Function.Instructions.Length - 1 - index;
             context.Function.Instructions[index].SBx = context.Function.Instructions.Length - 1 - index;
         }
         }
-
+        
         return true;
         return true;
     }
     }
 
 
@@ -821,14 +824,14 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
         context.Function.LoopLevel++;
         context.Function.LoopLevel++;
 
 
         using var scopeContext = context.CreateChildScope();
         using var scopeContext = context.CreateChildScope();
-
+        var stackPosition = scopeContext.StackPosition;
         foreach (var childNode in node.Nodes)
         foreach (var childNode in node.Nodes)
         {
         {
             childNode.Accept(this, scopeContext);
             childNode.Accept(this, scopeContext);
         }
         }
 
 
         CompileConditionNode(node.ConditionNode, scopeContext, true);
         CompileConditionNode(node.ConditionNode, scopeContext, true);
-        var a = scopeContext.HasCapturedLocalVariables ? scopeContext.StackPosition : (byte)0;
+        var a = scopeContext.HasCapturedLocalVariables ? (byte)(stackPosition + 1) : (byte)0;
         scopeContext.PushInstruction(Instruction.Jmp(a, startIndex - scopeContext.Function.Instructions.Length - 1), node.Position);
         scopeContext.PushInstruction(Instruction.Jmp(a, startIndex - scopeContext.Function.Instructions.Length - 1), node.Position);
         scopeContext.TryPushCloseUpValue(scopeContext.StackPosition, node.Position);
         scopeContext.TryPushCloseUpValue(scopeContext.StackPosition, node.Position);
 
 
@@ -848,6 +851,7 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
         context.Function.LoopLevel++;
         context.Function.LoopLevel++;
 
 
         using var scopeContext = context.CreateChildScope();
         using var scopeContext = context.CreateChildScope();
+        var stackPosition = scopeContext.StackPosition;
 
 
         foreach (var childNode in node.Nodes)
         foreach (var childNode in node.Nodes)
         {
         {
@@ -860,7 +864,7 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
         scopeContext.Function.Instructions[conditionIndex].SBx = scopeContext.Function.Instructions.Length - 1 - conditionIndex;
         scopeContext.Function.Instructions[conditionIndex].SBx = scopeContext.Function.Instructions.Length - 1 - conditionIndex;
 
 
         CompileConditionNode(node.ConditionNode, scopeContext, false);
         CompileConditionNode(node.ConditionNode, scopeContext, false);
-        var a = scopeContext.HasCapturedLocalVariables ? scopeContext.StackPosition : (byte)0;
+        var a = scopeContext.HasCapturedLocalVariables ? (byte)(1 + stackPosition) : (byte)0;
         scopeContext.PushInstruction(Instruction.Jmp(a, conditionIndex - context.Function.Instructions.Length), node.Position);
         scopeContext.PushInstruction(Instruction.Jmp(a, conditionIndex - context.Function.Instructions.Length), node.Position);
         scopeContext.TryPushCloseUpValue(scopeContext.StackPosition, node.Position);
         scopeContext.TryPushCloseUpValue(scopeContext.StackPosition, node.Position);
 
 

+ 26 - 10
src/Lua/Runtime/Closure.cs

@@ -1,3 +1,4 @@
+using System.Runtime.CompilerServices;
 using Lua.Internal;
 using Lua.Internal;
 
 
 namespace Lua.Runtime;
 namespace Lua.Runtime;
@@ -16,30 +17,45 @@ public sealed class Closure : LuaFunction
         for (int i = 0; i < proto.UpValues.Length; i++)
         for (int i = 0; i < proto.UpValues.Length; i++)
         {
         {
             var description = proto.UpValues[i];
             var description = proto.UpValues[i];
-            var upValue = GetUpValueFromDescription(state, environment == null ? state.EnvUpValue : UpValue.Closed(environment), proto, description, 1);
+            var upValue = GetUpValueFromDescription(state, state.CurrentThread,environment == null ? state.EnvUpValue : UpValue.Closed(environment), description);
             upValues.Add(upValue);
             upValues.Add(upValue);
         }
         }
     }
     }
-
+    
     public Chunk Proto => proto;
     public Chunk Proto => proto;
     public ReadOnlySpan<UpValue> UpValues => upValues.AsSpan();
     public ReadOnlySpan<UpValue> UpValues => upValues.AsSpan();
 
 
-    static UpValue GetUpValueFromDescription(LuaState state, UpValue envUpValue, Chunk proto, UpValueInfo description, int depth)
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal LuaValue GetUpValue(int index)
+    {
+        return upValues[index].GetValue();
+    }
+
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal void SetUpValue(int index, LuaValue value)
     {
     {
+        upValues[index].SetValue(value);
+    }
+
+    static UpValue GetUpValueFromDescription(LuaState state, LuaThread thread,UpValue envUpValue,  UpValueInfo description)
+    {
+        
         if (description.IsInRegister)
         if (description.IsInRegister)
         {
         {
-            var thread = state.CurrentThread;
-            var callStack = thread.GetCallStackFrames();
-            var frame = callStack[^depth];
-            return state.GetOrAddUpValue(thread, frame.Base + description.Index);
+            return state.GetOrAddUpValue(thread, thread.GetCallStackFrames()[^1].Base+ description.Index);
+            
         }
         }
-        else if (description.Index == -1) // -1 is global environment
+        if (description.Index == -1) // -1 is global environment
         {
         {
             return envUpValue;
             return envUpValue;
         }
         }
-        else
         {
         {
-            return GetUpValueFromDescription(state, envUpValue, proto.Parent!, proto.Parent!.UpValues[description.Index], depth + 1);
+            if (thread.GetCallStackFrames()[^1].Function is Closure parentClosure)
+            {
+                return parentClosure.UpValues[description.Index];
+            }
+
+            throw new Exception();
         }
         }
     }
     }
 }
 }

+ 1 - 1
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -532,7 +532,7 @@ public static partial class LuaVirtualMachine
                         pc += instruction.SBx;
                         pc += instruction.SBx;
                         if (instruction.A != 0)
                         if (instruction.A != 0)
                         {
                         {
-                            state.CloseUpValues(thread, instruction.A - 1);
+                            state.CloseUpValues(thread, frame.Base + instruction.A - 1);
                         }
                         }
                         break;
                         break;
                     case OpCode.Eq:
                     case OpCode.Eq: