|
|
@@ -7,13 +7,6 @@ namespace Lua.CodeAnalysis.Compilation;
|
|
|
|
|
|
public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bool>
|
|
|
{
|
|
|
- enum CallFunctionType
|
|
|
- {
|
|
|
- Expression,
|
|
|
- Statement,
|
|
|
- TailCall
|
|
|
- }
|
|
|
-
|
|
|
public static readonly LuaCompiler Default = new();
|
|
|
|
|
|
/// <summary>
|
|
|
@@ -86,11 +79,15 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
|
|
|
// vararg
|
|
|
public bool VisitVariableArgumentsExpressionNode(VariableArgumentsExpressionNode node, ScopeCompilationContext context)
|
|
|
{
|
|
|
- // TODO: optimize
|
|
|
- context.PushInstruction(Instruction.VarArg(context.StackPosition, 0), node.Position, true);
|
|
|
+ CompileVariableArgumentsExpression(node, context);
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
+ void CompileVariableArgumentsExpression(VariableArgumentsExpressionNode node, ScopeCompilationContext context, int resultCount = -1)
|
|
|
+ {
|
|
|
+ context.PushInstruction(Instruction.VarArg(context.StackPosition, (ushort)(resultCount == -1 ? 0 : resultCount + 1)), node.Position, true);
|
|
|
+ }
|
|
|
+
|
|
|
// Unary/Binary expression
|
|
|
public bool VisitUnaryExpressionNode(UnaryExpressionNode node, ScopeCompilationContext context)
|
|
|
{
|
|
|
@@ -306,17 +303,17 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
|
|
|
|
|
|
public bool VisitCallTableMethodExpressionNode(CallTableMethodExpressionNode node, ScopeCompilationContext context)
|
|
|
{
|
|
|
- CompileTableMethod(node, context, CallFunctionType.Expression);
|
|
|
+ CompileTableMethod(node, context, false, 1);
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
public bool VisitCallTableMethodStatementNode(CallTableMethodStatementNode node, ScopeCompilationContext context)
|
|
|
{
|
|
|
- CompileTableMethod(node.Expression, context, CallFunctionType.Statement);
|
|
|
+ CompileTableMethod(node.Expression, context, false, 0);
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
- void CompileTableMethod(CallTableMethodExpressionNode node, ScopeCompilationContext context, CallFunctionType callType)
|
|
|
+ void CompileTableMethod(CallTableMethodExpressionNode node, ScopeCompilationContext context, bool isTailCall, int resultCount)
|
|
|
{
|
|
|
// load table
|
|
|
var tablePosition = context.StackPosition;
|
|
|
@@ -330,31 +327,25 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
|
|
|
context.StackPosition = (byte)(tablePosition + 2);
|
|
|
|
|
|
// load arguments
|
|
|
- foreach (var argument in node.ArgumentNodes)
|
|
|
- {
|
|
|
- argument.Accept(this, context);
|
|
|
- }
|
|
|
-
|
|
|
var b = node.ArgumentNodes.Length + 2;
|
|
|
if (node.ArgumentNodes.Length > 0 && !IsFixedNumberOfReturnValues(node.ArgumentNodes[^1]))
|
|
|
{
|
|
|
b = 0;
|
|
|
}
|
|
|
|
|
|
+ CompileExpressionList(node, node.ArgumentNodes, b - 1, context);
|
|
|
+
|
|
|
// push call interuction
|
|
|
- switch (callType)
|
|
|
+ if (isTailCall)
|
|
|
{
|
|
|
- case CallFunctionType.Expression:
|
|
|
- context.PushInstruction(Instruction.Call(tablePosition, (ushort)b, 0), node.Position);
|
|
|
- break;
|
|
|
- case CallFunctionType.Statement:
|
|
|
- context.PushInstruction(Instruction.Call(tablePosition, (ushort)b, 1), node.Position);
|
|
|
- break;
|
|
|
- case CallFunctionType.TailCall:
|
|
|
- context.PushInstruction(Instruction.TailCall(tablePosition, (ushort)b, 0), node.Position);
|
|
|
- break;
|
|
|
+ context.PushInstruction(Instruction.TailCall(tablePosition, (ushort)b, 0), node.Position);
|
|
|
+ context.StackPosition = tablePosition;
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ context.PushInstruction(Instruction.Call(tablePosition, (ushort)b, (ushort)(resultCount < 0 ? 0 : resultCount + 1)), node.Position);
|
|
|
+ context.StackPosition = (byte)(tablePosition + resultCount);
|
|
|
}
|
|
|
- context.StackPosition = (byte)(tablePosition + 1);
|
|
|
}
|
|
|
|
|
|
// return
|
|
|
@@ -369,12 +360,12 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
|
|
|
|
|
|
if (lastNode is CallFunctionExpressionNode call)
|
|
|
{
|
|
|
- CompileCallFunctionExpression(call, context, CallFunctionType.TailCall);
|
|
|
+ CompileCallFunctionExpression(call, context, true, -1);
|
|
|
return true;
|
|
|
}
|
|
|
else if (lastNode is CallTableMethodExpressionNode callMethod)
|
|
|
{
|
|
|
- CompileTableMethod(callMethod, context, CallFunctionType.TailCall);
|
|
|
+ CompileTableMethod(callMethod, context, true, -1);
|
|
|
return true;
|
|
|
}
|
|
|
}
|
|
|
@@ -383,10 +374,7 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
|
|
|
? (ushort)0
|
|
|
: (ushort)(node.Nodes.Length + 1);
|
|
|
|
|
|
- foreach (var childNode in node.Nodes)
|
|
|
- {
|
|
|
- childNode.Accept(this, context);
|
|
|
- }
|
|
|
+ CompileExpressionList(node, node.Nodes, b - 1, context);
|
|
|
|
|
|
context.PushInstruction(Instruction.Return((byte)(context.StackPosition - node.Nodes.Length), b), node.Position);
|
|
|
|
|
|
@@ -396,47 +384,27 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
|
|
|
// assignment
|
|
|
public bool VisitLocalAssignmentStatementNode(LocalAssignmentStatementNode node, ScopeCompilationContext context)
|
|
|
{
|
|
|
+ var startPosition = context.StackPosition;
|
|
|
+ CompileExpressionList(node, node.RightNodes, node.LeftNodes.Length, context);
|
|
|
+
|
|
|
for (int i = 0; i < node.Identifiers.Length; i++)
|
|
|
{
|
|
|
- if (node.RightNodes.Length > i)
|
|
|
- {
|
|
|
- // Load initial values for variables
|
|
|
- var expression = node.RightNodes[i];
|
|
|
- expression.Accept(this, context);
|
|
|
+ context.StackPosition = (byte)(startPosition + i + 1);
|
|
|
|
|
|
- var identifier = node.Identifiers[i];
|
|
|
+ var identifier = node.Identifiers[i];
|
|
|
|
|
|
- if (context.TryGetLocalVariableInThisScope(identifier.Name, out var variable))
|
|
|
- {
|
|
|
- // assign local variable
|
|
|
- context.PushInstruction(Instruction.Move(variable.RegisterIndex, (ushort)(context.StackPosition - 1)), node.Position, true);
|
|
|
- }
|
|
|
- else
|
|
|
- {
|
|
|
- // register local variable
|
|
|
- context.AddLocalVariable(identifier.Name, new()
|
|
|
- {
|
|
|
- RegisterIndex = (byte)(context.StackPosition - 1),
|
|
|
- });
|
|
|
- }
|
|
|
+ if (context.TryGetLocalVariableInThisScope(identifier.Name, out var variable))
|
|
|
+ {
|
|
|
+ // assign local variable
|
|
|
+ context.PushInstruction(Instruction.Move(variable.RegisterIndex, (ushort)(context.StackPosition - 1)), node.Position, true);
|
|
|
}
|
|
|
else
|
|
|
{
|
|
|
- // assigning nil to variables that do not have an initial value
|
|
|
- var varCount = node.Identifiers.Length - i;
|
|
|
- context.PushInstruction(Instruction.LoadNil(context.StackPosition, (ushort)varCount), node.Position);
|
|
|
- context.StackPosition = (byte)(context.StackPosition + varCount);
|
|
|
-
|
|
|
- // register local variables
|
|
|
- for (int n = 0; n < varCount; n++)
|
|
|
+ // register local variable
|
|
|
+ context.AddLocalVariable(identifier.Name, new()
|
|
|
{
|
|
|
- context.AddLocalVariable(node.Identifiers[i + n].Name, new()
|
|
|
- {
|
|
|
- RegisterIndex = (byte)(context.StackPosition + n - 1),
|
|
|
- });
|
|
|
- }
|
|
|
-
|
|
|
- break;
|
|
|
+ RegisterIndex = (byte)(context.StackPosition - 1),
|
|
|
+ });
|
|
|
}
|
|
|
}
|
|
|
return true;
|
|
|
@@ -446,13 +414,13 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
|
|
|
{
|
|
|
var startPosition = context.StackPosition;
|
|
|
|
|
|
+ CompileExpressionList(node, node.RightNodes, node.LeftNodes.Length, context);
|
|
|
+
|
|
|
for (int i = 0; i < node.LeftNodes.Length; i++)
|
|
|
{
|
|
|
- var expression = node.RightNodes[i];
|
|
|
+ context.StackPosition = (byte)(startPosition + i + 1);
|
|
|
var leftNode = node.LeftNodes[i];
|
|
|
|
|
|
- expression.Accept(this, context);
|
|
|
-
|
|
|
switch (leftNode)
|
|
|
{
|
|
|
case IdentifierNode identifier:
|
|
|
@@ -513,47 +481,42 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
|
|
|
// function call
|
|
|
public bool VisitCallFunctionStatementNode(CallFunctionStatementNode node, ScopeCompilationContext context)
|
|
|
{
|
|
|
- CompileCallFunctionExpression(node.Expression, context, CallFunctionType.Statement);
|
|
|
+ CompileCallFunctionExpression(node.Expression, context, false, 0);
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
public bool VisitCallFunctionExpressionNode(CallFunctionExpressionNode node, ScopeCompilationContext context)
|
|
|
{
|
|
|
- CompileCallFunctionExpression(node, context, CallFunctionType.Expression);
|
|
|
+ CompileCallFunctionExpression(node, context, false, 1);
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
- void CompileCallFunctionExpression(CallFunctionExpressionNode node, ScopeCompilationContext context, CallFunctionType callType)
|
|
|
+ void CompileCallFunctionExpression(CallFunctionExpressionNode node, ScopeCompilationContext context, bool isTailCall, int resultCount)
|
|
|
{
|
|
|
// get closure
|
|
|
var r = context.StackPosition;
|
|
|
node.FunctionNode.Accept(this, context);
|
|
|
|
|
|
- foreach (var argument in node.ArgumentNodes)
|
|
|
- {
|
|
|
- argument.Accept(this, context);
|
|
|
- }
|
|
|
-
|
|
|
+ // load arguments
|
|
|
var b = node.ArgumentNodes.Length + 1;
|
|
|
if (node.ArgumentNodes.Length > 0 && !IsFixedNumberOfReturnValues(node.ArgumentNodes[^1]))
|
|
|
{
|
|
|
b = 0;
|
|
|
}
|
|
|
|
|
|
+ CompileExpressionList(node, node.ArgumentNodes, b - 1, context);
|
|
|
+
|
|
|
// push call interuction
|
|
|
- switch (callType)
|
|
|
+ if (isTailCall)
|
|
|
{
|
|
|
- case CallFunctionType.Expression:
|
|
|
- context.PushInstruction(Instruction.Call(r, (ushort)b, 0), node.Position);
|
|
|
- break;
|
|
|
- case CallFunctionType.Statement:
|
|
|
- context.PushInstruction(Instruction.Call(r, (ushort)b, 1), node.Position);
|
|
|
- break;
|
|
|
- case CallFunctionType.TailCall:
|
|
|
- context.PushInstruction(Instruction.TailCall(r, (ushort)b, 0), node.Position);
|
|
|
- break;
|
|
|
+ context.PushInstruction(Instruction.TailCall(r, (ushort)b, 0), node.Position);
|
|
|
+ context.StackPosition = r;
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ context.PushInstruction(Instruction.Call(r, (ushort)b, (ushort)(resultCount == -1 ? 0 : resultCount + 1)), node.Position);
|
|
|
+ context.StackPosition = (byte)(r + resultCount);
|
|
|
}
|
|
|
- context.StackPosition = (byte)(r + 1);
|
|
|
}
|
|
|
|
|
|
// function declaration
|
|
|
@@ -1035,6 +998,46 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
|
|
|
context.PushInstruction(Instruction.Test((byte)(context.StackPosition - 1), falseIsSkip ? (byte)0 : (byte)1), node.Position);
|
|
|
}
|
|
|
|
|
|
+ void CompileExpressionList(SyntaxNode rootNode, ExpressionNode[] expressions, int minimumCount, ScopeCompilationContext context)
|
|
|
+ {
|
|
|
+ var isLastFunction = false;
|
|
|
+ for (int i = 0; i < expressions.Length; i++)
|
|
|
+ {
|
|
|
+ var expression = expressions[i];
|
|
|
+ var remaining = expressions.Length - i + 1;
|
|
|
+ var isLast = i == expressions.Length - 1;
|
|
|
+
|
|
|
+ if (expression is CallFunctionExpressionNode call)
|
|
|
+ {
|
|
|
+ CompileCallFunctionExpression(call, context, false, isLast ? remaining : 1);
|
|
|
+ isLastFunction = isLast;
|
|
|
+ }
|
|
|
+ else if (expression is CallTableMethodExpressionNode method)
|
|
|
+ {
|
|
|
+ CompileTableMethod(method, context, false, isLast ? remaining : 1);
|
|
|
+ isLastFunction = isLast;
|
|
|
+ }
|
|
|
+ else if (expression is VariableArgumentsExpressionNode varArg)
|
|
|
+ {
|
|
|
+ CompileVariableArgumentsExpression(varArg, context, isLast ? remaining : 1);
|
|
|
+ isLastFunction = isLast;
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ expression.Accept(this, context);
|
|
|
+ isLastFunction = false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // fill space with nil
|
|
|
+ var varCount = minimumCount - expressions.Length;
|
|
|
+ if (varCount > 0 && !isLastFunction)
|
|
|
+ {
|
|
|
+ context.PushInstruction(Instruction.LoadNil(context.StackPosition, (ushort)varCount), rootNode.Position);
|
|
|
+ context.StackPosition = (byte)(context.StackPosition + varCount);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
(byte b, byte c) GetBAndC(BinaryExpressionNode node, ScopeCompilationContext context)
|
|
|
{
|
|
|
byte b, c;
|