Browse Source

Imrove: setting the instruction parameters according to the number of return values

AnnulusGames 1 year ago
parent
commit
678a6fc61d

+ 92 - 89
src/Lua/CodeAnalysis/Compilation/LuaCompiler.cs

@@ -7,13 +7,6 @@ namespace Lua.CodeAnalysis.Compilation;
 
 
 public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bool>
 public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bool>
 {
 {
-    enum CallFunctionType
-    {
-        Expression,
-        Statement,
-        TailCall
-    }
-
     public static readonly LuaCompiler Default = new();
     public static readonly LuaCompiler Default = new();
 
 
     /// <summary>
     /// <summary>
@@ -86,11 +79,15 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
     // vararg
     // vararg
     public bool VisitVariableArgumentsExpressionNode(VariableArgumentsExpressionNode node, ScopeCompilationContext context)
     public bool VisitVariableArgumentsExpressionNode(VariableArgumentsExpressionNode node, ScopeCompilationContext context)
     {
     {
-        // TODO: optimize
-        context.PushInstruction(Instruction.VarArg(context.StackPosition, 0), node.Position, true);
+        CompileVariableArgumentsExpression(node, context);
         return true;
         return true;
     }
     }
 
 
+    void CompileVariableArgumentsExpression(VariableArgumentsExpressionNode node, ScopeCompilationContext context, int resultCount = -1)
+    {
+        context.PushInstruction(Instruction.VarArg(context.StackPosition, (ushort)(resultCount == -1 ? 0 : resultCount + 1)), node.Position, true);
+    }
+
     // Unary/Binary expression
     // Unary/Binary expression
     public bool VisitUnaryExpressionNode(UnaryExpressionNode node, ScopeCompilationContext context)
     public bool VisitUnaryExpressionNode(UnaryExpressionNode node, ScopeCompilationContext context)
     {
     {
@@ -306,17 +303,17 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
 
 
     public bool VisitCallTableMethodExpressionNode(CallTableMethodExpressionNode node, ScopeCompilationContext context)
     public bool VisitCallTableMethodExpressionNode(CallTableMethodExpressionNode node, ScopeCompilationContext context)
     {
     {
-        CompileTableMethod(node, context, CallFunctionType.Expression);
+        CompileTableMethod(node, context, false, 1);
         return true;
         return true;
     }
     }
 
 
     public bool VisitCallTableMethodStatementNode(CallTableMethodStatementNode node, ScopeCompilationContext context)
     public bool VisitCallTableMethodStatementNode(CallTableMethodStatementNode node, ScopeCompilationContext context)
     {
     {
-        CompileTableMethod(node.Expression, context, CallFunctionType.Statement);
+        CompileTableMethod(node.Expression, context, false, 0);
         return true;
         return true;
     }
     }
 
 
-    void CompileTableMethod(CallTableMethodExpressionNode node, ScopeCompilationContext context, CallFunctionType callType)
+    void CompileTableMethod(CallTableMethodExpressionNode node, ScopeCompilationContext context, bool isTailCall, int resultCount)
     {
     {
         // load table
         // load table
         var tablePosition = context.StackPosition;
         var tablePosition = context.StackPosition;
@@ -330,31 +327,25 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
         context.StackPosition = (byte)(tablePosition + 2);
         context.StackPosition = (byte)(tablePosition + 2);
 
 
         // load arguments
         // load arguments
-        foreach (var argument in node.ArgumentNodes)
-        {
-            argument.Accept(this, context);
-        }
-
         var b = node.ArgumentNodes.Length + 2;
         var b = node.ArgumentNodes.Length + 2;
         if (node.ArgumentNodes.Length > 0 && !IsFixedNumberOfReturnValues(node.ArgumentNodes[^1]))
         if (node.ArgumentNodes.Length > 0 && !IsFixedNumberOfReturnValues(node.ArgumentNodes[^1]))
         {
         {
             b = 0;
             b = 0;
         }
         }
 
 
+        CompileExpressionList(node, node.ArgumentNodes, b - 1, context);
+
         // push call interuction
         // push call interuction
-        switch (callType)
+        if (isTailCall)
         {
         {
-            case CallFunctionType.Expression:
-                context.PushInstruction(Instruction.Call(tablePosition, (ushort)b, 0), node.Position);
-                break;
-            case CallFunctionType.Statement:
-                context.PushInstruction(Instruction.Call(tablePosition, (ushort)b, 1), node.Position);
-                break;
-            case CallFunctionType.TailCall:
-                context.PushInstruction(Instruction.TailCall(tablePosition, (ushort)b, 0), node.Position);
-                break;
+            context.PushInstruction(Instruction.TailCall(tablePosition, (ushort)b, 0), node.Position);
+            context.StackPosition = tablePosition;
+        }
+        else
+        {
+            context.PushInstruction(Instruction.Call(tablePosition, (ushort)b, (ushort)(resultCount < 0 ? 0 : resultCount + 1)), node.Position);
+            context.StackPosition = (byte)(tablePosition + resultCount);
         }
         }
-        context.StackPosition = (byte)(tablePosition + 1);
     }
     }
 
 
     // return
     // return
@@ -369,12 +360,12 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
 
 
             if (lastNode is CallFunctionExpressionNode call)
             if (lastNode is CallFunctionExpressionNode call)
             {
             {
-                CompileCallFunctionExpression(call, context, CallFunctionType.TailCall);
+                CompileCallFunctionExpression(call, context, true, -1);
                 return true;
                 return true;
             }
             }
             else if (lastNode is CallTableMethodExpressionNode callMethod)
             else if (lastNode is CallTableMethodExpressionNode callMethod)
             {
             {
-                CompileTableMethod(callMethod, context, CallFunctionType.TailCall);
+                CompileTableMethod(callMethod, context, true, -1);
                 return true;
                 return true;
             }
             }
         }
         }
@@ -383,10 +374,7 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
             ? (ushort)0
             ? (ushort)0
             : (ushort)(node.Nodes.Length + 1);
             : (ushort)(node.Nodes.Length + 1);
 
 
-        foreach (var childNode in node.Nodes)
-        {
-            childNode.Accept(this, context);
-        }
+        CompileExpressionList(node, node.Nodes, b - 1, context);
 
 
         context.PushInstruction(Instruction.Return((byte)(context.StackPosition - node.Nodes.Length), b), node.Position);
         context.PushInstruction(Instruction.Return((byte)(context.StackPosition - node.Nodes.Length), b), node.Position);
 
 
@@ -396,47 +384,27 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
     // assignment
     // assignment
     public bool VisitLocalAssignmentStatementNode(LocalAssignmentStatementNode node, ScopeCompilationContext context)
     public bool VisitLocalAssignmentStatementNode(LocalAssignmentStatementNode node, ScopeCompilationContext context)
     {
     {
+        var startPosition = context.StackPosition;
+        CompileExpressionList(node, node.RightNodes, node.LeftNodes.Length, context);
+
         for (int i = 0; i < node.Identifiers.Length; i++)
         for (int i = 0; i < node.Identifiers.Length; i++)
         {
         {
-            if (node.RightNodes.Length > i)
-            {
-                // Load initial values ​​for variables
-                var expression = node.RightNodes[i];
-                expression.Accept(this, context);
+            context.StackPosition = (byte)(startPosition + i + 1);
 
 
-                var identifier = node.Identifiers[i];
+            var identifier = node.Identifiers[i];
 
 
-                if (context.TryGetLocalVariableInThisScope(identifier.Name, out var variable))
-                {
-                    // assign local variable
-                    context.PushInstruction(Instruction.Move(variable.RegisterIndex, (ushort)(context.StackPosition - 1)), node.Position, true);
-                }
-                else
-                {
-                    // register local variable
-                    context.AddLocalVariable(identifier.Name, new()
-                    {
-                        RegisterIndex = (byte)(context.StackPosition - 1),
-                    });
-                }
+            if (context.TryGetLocalVariableInThisScope(identifier.Name, out var variable))
+            {
+                // assign local variable
+                context.PushInstruction(Instruction.Move(variable.RegisterIndex, (ushort)(context.StackPosition - 1)), node.Position, true);
             }
             }
             else
             else
             {
             {
-                // assigning nil to variables that do not have an initial value
-                var varCount = node.Identifiers.Length - i;
-                context.PushInstruction(Instruction.LoadNil(context.StackPosition, (ushort)varCount), node.Position);
-                context.StackPosition = (byte)(context.StackPosition + varCount);
-
-                // register local variables
-                for (int n = 0; n < varCount; n++)
+                // register local variable
+                context.AddLocalVariable(identifier.Name, new()
                 {
                 {
-                    context.AddLocalVariable(node.Identifiers[i + n].Name, new()
-                    {
-                        RegisterIndex = (byte)(context.StackPosition + n - 1),
-                    });
-                }
-
-                break;
+                    RegisterIndex = (byte)(context.StackPosition - 1),
+                });
             }
             }
         }
         }
         return true;
         return true;
@@ -446,13 +414,13 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
     {
     {
         var startPosition = context.StackPosition;
         var startPosition = context.StackPosition;
 
 
+        CompileExpressionList(node, node.RightNodes, node.LeftNodes.Length, context);
+
         for (int i = 0; i < node.LeftNodes.Length; i++)
         for (int i = 0; i < node.LeftNodes.Length; i++)
         {
         {
-            var expression = node.RightNodes[i];
+            context.StackPosition = (byte)(startPosition + i + 1);
             var leftNode = node.LeftNodes[i];
             var leftNode = node.LeftNodes[i];
 
 
-            expression.Accept(this, context);
-
             switch (leftNode)
             switch (leftNode)
             {
             {
                 case IdentifierNode identifier:
                 case IdentifierNode identifier:
@@ -513,47 +481,42 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
     // function call
     // function call
     public bool VisitCallFunctionStatementNode(CallFunctionStatementNode node, ScopeCompilationContext context)
     public bool VisitCallFunctionStatementNode(CallFunctionStatementNode node, ScopeCompilationContext context)
     {
     {
-        CompileCallFunctionExpression(node.Expression, context, CallFunctionType.Statement);
+        CompileCallFunctionExpression(node.Expression, context, false, 0);
         return true;
         return true;
     }
     }
 
 
     public bool VisitCallFunctionExpressionNode(CallFunctionExpressionNode node, ScopeCompilationContext context)
     public bool VisitCallFunctionExpressionNode(CallFunctionExpressionNode node, ScopeCompilationContext context)
     {
     {
-        CompileCallFunctionExpression(node, context, CallFunctionType.Expression);
+        CompileCallFunctionExpression(node, context, false, 1);
         return true;
         return true;
     }
     }
 
 
-    void CompileCallFunctionExpression(CallFunctionExpressionNode node, ScopeCompilationContext context, CallFunctionType callType)
+    void CompileCallFunctionExpression(CallFunctionExpressionNode node, ScopeCompilationContext context, bool isTailCall, int resultCount)
     {
     {
         // get closure
         // get closure
         var r = context.StackPosition;
         var r = context.StackPosition;
         node.FunctionNode.Accept(this, context);
         node.FunctionNode.Accept(this, context);
 
 
-        foreach (var argument in node.ArgumentNodes)
-        {
-            argument.Accept(this, context);
-        }
-
+        // load arguments
         var b = node.ArgumentNodes.Length + 1;
         var b = node.ArgumentNodes.Length + 1;
         if (node.ArgumentNodes.Length > 0 && !IsFixedNumberOfReturnValues(node.ArgumentNodes[^1]))
         if (node.ArgumentNodes.Length > 0 && !IsFixedNumberOfReturnValues(node.ArgumentNodes[^1]))
         {
         {
             b = 0;
             b = 0;
         }
         }
 
 
+        CompileExpressionList(node, node.ArgumentNodes, b - 1, context);
+
         // push call interuction
         // push call interuction
-        switch (callType)
+        if (isTailCall)
         {
         {
-            case CallFunctionType.Expression:
-                context.PushInstruction(Instruction.Call(r, (ushort)b, 0), node.Position);
-                break;
-            case CallFunctionType.Statement:
-                context.PushInstruction(Instruction.Call(r, (ushort)b, 1), node.Position);
-                break;
-            case CallFunctionType.TailCall:
-                context.PushInstruction(Instruction.TailCall(r, (ushort)b, 0), node.Position);
-                break;
+            context.PushInstruction(Instruction.TailCall(r, (ushort)b, 0), node.Position);
+            context.StackPosition = r;
+        }
+        else
+        {
+            context.PushInstruction(Instruction.Call(r, (ushort)b, (ushort)(resultCount == -1 ? 0 : resultCount + 1)), node.Position);
+            context.StackPosition = (byte)(r + resultCount);
         }
         }
-        context.StackPosition = (byte)(r + 1);
     }
     }
 
 
     // function declaration
     // function declaration
@@ -1035,6 +998,46 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
         context.PushInstruction(Instruction.Test((byte)(context.StackPosition - 1), falseIsSkip ? (byte)0 : (byte)1), node.Position);
         context.PushInstruction(Instruction.Test((byte)(context.StackPosition - 1), falseIsSkip ? (byte)0 : (byte)1), node.Position);
     }
     }
 
 
+    void CompileExpressionList(SyntaxNode rootNode, ExpressionNode[] expressions, int minimumCount, ScopeCompilationContext context)
+    {
+        var isLastFunction = false;
+        for (int i = 0; i < expressions.Length; i++)
+        {
+            var expression = expressions[i];
+            var remaining = expressions.Length - i + 1;
+            var isLast = i == expressions.Length - 1;
+
+            if (expression is CallFunctionExpressionNode call)
+            {
+                CompileCallFunctionExpression(call, context, false, isLast ? remaining : 1);
+                isLastFunction = isLast;
+            }
+            else if (expression is CallTableMethodExpressionNode method)
+            {
+                CompileTableMethod(method, context, false, isLast ? remaining : 1);
+                isLastFunction = isLast;
+            }
+            else if (expression is VariableArgumentsExpressionNode varArg)
+            {
+                CompileVariableArgumentsExpression(varArg, context, isLast ? remaining : 1);
+                isLastFunction = isLast;
+            }
+            else
+            {
+                expression.Accept(this, context);
+                isLastFunction = false;
+            }
+        }
+
+        // fill space with nil
+        var varCount = minimumCount - expressions.Length;
+        if (varCount > 0 && !isLastFunction)
+        {
+            context.PushInstruction(Instruction.LoadNil(context.StackPosition, (ushort)varCount), rootNode.Position);
+            context.StackPosition = (byte)(context.StackPosition + varCount);
+        }
+    }
+
     (byte b, byte c) GetBAndC(BinaryExpressionNode node, ScopeCompilationContext context)
     (byte b, byte c) GetBAndC(BinaryExpressionNode node, ScopeCompilationContext context)
     {
     {
         byte b, c;
         byte b, c;

+ 1 - 1
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -855,7 +855,7 @@ public static partial class LuaVirtualMachine
                         stack.EnsureCapacity(RA + count);
                         stack.EnsureCapacity(RA + count);
                         for (int i = 0; i < count; i++)
                         for (int i = 0; i < count; i++)
                         {
                         {
-                            stack.UnsafeGet(RA + i) = stack.UnsafeGet(frame.Base - (count - i));
+                            stack.UnsafeGet(RA + i) = stack.UnsafeGet(frame.Base - (frame.VariableArgumentCount - i));
                         }
                         }
                         stack.NotifyTop(RA + count);
                         stack.NotifyTop(RA + count);
                     }
                     }