Browse Source

Fix: `or` doesn't work well

Akeit0 9 months ago
parent
commit
b6567ced0c

+ 33 - 24
src/Lua/CodeAnalysis/Compilation/FunctionCompilationContext.cs

@@ -119,7 +119,7 @@ public class FunctionCompilationContext : IDisposable
     /// Push or merge the new instruction.
     /// </summary>
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
-    public void PushOrMergeInstruction(int lastLocal, in Instruction instruction, in SourcePosition position, ref bool incrementStackPosition)
+    internal void PushOrMergeInstruction(in Instruction instruction, in SourcePosition position, ref bool incrementStackPosition)
     {
         if (instructions.Length == 0)
         {
@@ -127,27 +127,41 @@ public class FunctionCompilationContext : IDisposable
             instructionPositions.Add(position);
             return;
         }
+
+        var activeLocals = Scope.ActiveLocalVariables;
+
         ref var lastInstruction = ref instructions.AsSpan()[^1];
         var opcode = instruction.OpCode;
         switch (opcode)
         {
             case OpCode.Move:
-                // last A is not local variable
-                if (lastInstruction.A != lastLocal &&
-                    // available to merge
-                    lastInstruction.A == instruction.B &&
-                    // not already merged
-                    lastInstruction.A != lastInstruction.B)
+
+                if (
+                    // available to merge and  last A is not local variable
+                    lastInstruction.A == instruction.B && !activeLocals[lastInstruction.A])
                 {
                     switch (lastInstruction.OpCode)
                     {
+                        case OpCode.LoadK:
+                        case OpCode.LoadBool when lastInstruction.C == 0:
+                        case OpCode.LoadNil when lastInstruction.B == 0:
+                        case OpCode.GetUpVal:
+                        case OpCode.GetTabUp:
                         case OpCode.GetTable:
+                        case OpCode.SetTabUp:
+                        case OpCode.SetUpVal:
+                        case OpCode.SetTable:
+                        case OpCode.NewTable:
+                        case OpCode.Self:
                         case OpCode.Add:
                         case OpCode.Sub:
                         case OpCode.Mul:
                         case OpCode.Div:
                         case OpCode.Mod:
                         case OpCode.Pow:
+                        case OpCode.Unm:
+                        case OpCode.Not:
+                        case OpCode.Len:
                         case OpCode.Concat:
                             {
                                 lastInstruction.A = instruction.A;
@@ -156,11 +170,12 @@ public class FunctionCompilationContext : IDisposable
                             }
                     }
                 }
+
                 break;
             case OpCode.GetTable:
                 {
                     // Merge MOVE GetTable
-                    if (lastInstruction.OpCode == OpCode.Move && lastLocal != lastInstruction.A)
+                    if (lastInstruction.OpCode == OpCode.Move && !activeLocals[lastInstruction.A])
                     {
                         if (lastInstruction.A == instruction.B)
                         {
@@ -169,14 +184,14 @@ public class FunctionCompilationContext : IDisposable
                             incrementStackPosition = false;
                             return;
                         }
-
                     }
+
                     break;
                 }
             case OpCode.SetTable:
                 {
                     // Merge MOVE SETTABLE
-                    if (lastInstruction.OpCode == OpCode.Move && lastLocal != lastInstruction.A)
+                    if (lastInstruction.OpCode == OpCode.Move && !activeLocals[lastInstruction.A])
                     {
                         var lastB = lastInstruction.B;
                         var lastA = lastInstruction.A;
@@ -187,7 +202,7 @@ public class FunctionCompilationContext : IDisposable
                             {
                                 ref var last2Instruction = ref instructions.AsSpan()[^2];
                                 var last2A = last2Instruction.A;
-                                if (last2Instruction.OpCode == OpCode.Move && lastLocal != last2A && instruction.C == last2A)
+                                if (last2Instruction.OpCode == OpCode.Move && !activeLocals[last2A] && instruction.C == last2A)
                                 {
                                     last2Instruction = Instruction.SetTable((byte)(lastB), instruction.B, last2Instruction.B);
                                     instructions.RemoveAtSwapback(instructions.Length - 1);
@@ -197,6 +212,7 @@ public class FunctionCompilationContext : IDisposable
                                     return;
                                 }
                             }
+
                             lastInstruction = Instruction.SetTable((byte)(lastB), instruction.B, instruction.C);
                             instructionPositions[^1] = position;
                             incrementStackPosition = false;
@@ -217,9 +233,8 @@ public class FunctionCompilationContext : IDisposable
                         var last2OpCode = last2Instruction.OpCode;
                         if (last2OpCode is OpCode.LoadK or OpCode.Move)
                         {
-
                             var last2A = last2Instruction.A;
-                            if (last2A != lastLocal && instruction.C == last2A)
+                            if (!activeLocals[last2A] && instruction.C == last2A)
                             {
                                 var c = last2OpCode == OpCode.LoadK ? last2Instruction.Bx + 256 : last2Instruction.B;
                                 last2Instruction = lastInstruction;
@@ -231,27 +246,20 @@ public class FunctionCompilationContext : IDisposable
                             }
                         }
                     }
+
                     break;
                 }
             case OpCode.Unm:
             case OpCode.Not:
             case OpCode.Len:
-                if (lastInstruction.OpCode == OpCode.Move && lastLocal != lastInstruction.A && lastInstruction.A == instruction.B)
+                if (lastInstruction.OpCode == OpCode.Move && !activeLocals[lastInstruction.A] && lastInstruction.A == instruction.B)
                 {
-                    lastInstruction = instruction with { B = lastInstruction.B }; ;
-                    instructionPositions[^1] = position;
-                    incrementStackPosition = false;
-                    return;
-                }
-                break;
-            case OpCode.Return:
-                if (lastInstruction.OpCode == OpCode.Move && instruction.B == 2 && lastInstruction.B < 256)
-                {
-                    lastInstruction = instruction with { A = (byte)lastInstruction.B };
+                    lastInstruction = instruction with { B = lastInstruction.B };
                     instructionPositions[^1] = position;
                     incrementStackPosition = false;
                     return;
                 }
+
                 break;
         }
 
@@ -378,6 +386,7 @@ public class FunctionCompilationContext : IDisposable
             {
                 instruction.A = startPosition;
             }
+
             instruction.SBx = endPosition - description.Index;
         }
 

+ 13 - 1
src/Lua/CodeAnalysis/Compilation/LuaCompiler.cs

@@ -144,7 +144,13 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
                 a = context.StackTopPosition;
             }
 
-            context.PushInstruction(Instruction.Test(a, (byte)(node.OperatorType is BinaryOperator.And ? 0 : 1)), node.Position);
+            context.PushInstruction(Instruction.Test(a, 0), node.Position);
+            if (node.OperatorType is BinaryOperator.Or)
+            {
+                context.PushInstruction(Instruction.Jmp(0, 2), node.Position);
+                context.PushInstruction(Instruction.Move(r, a), node.Position);
+            }
+
             var testJmpIndex = context.Function.Instructions.Length;
             context.PushInstruction(Instruction.Jmp(0, 0), node.Position);
 
@@ -469,6 +475,7 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
                 });
             }
         }
+
         return true;
     }
 
@@ -1102,6 +1109,7 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
                 {
                     value = stringLiteral.Text.ToString();
                 }
+
                 return true;
             case UnaryExpressionNode unaryExpression:
                 if (TryGetConstant(unaryExpression.Node, context, out var unaryNodeValue))
@@ -1114,6 +1122,7 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
                                 value = -d1;
                                 return true;
                             }
+
                             break;
                         case UnaryOperator.Not:
                             if (unaryNodeValue.TryRead<bool>(out var b))
@@ -1121,9 +1130,11 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
                                 value = !b;
                                 return true;
                             }
+
                             break;
                     }
                 }
+
                 break;
             case BinaryExpressionNode binaryExpression:
                 if (TryGetConstant(binaryExpression.LeftNode, context, out var leftValue) &&
@@ -1169,6 +1180,7 @@ public sealed class LuaCompiler : ISyntaxNodeVisitor<ScopeCompilationContext, bo
                             break;
                     }
                 }
+
                 break;
         }
 

+ 4 - 4
src/Lua/CodeAnalysis/Compilation/ScopeCompilationContext.cs

@@ -33,7 +33,7 @@ public class ScopeCompilationContext : IDisposable
     readonly Dictionary<ReadOnlyMemory<char>, LocalVariableDescription> localVariables = new(256, Utf16StringMemoryComparer.Default);
     readonly Dictionary<ReadOnlyMemory<char>, LabelDescription> labels = new(32, Utf16StringMemoryComparer.Default);
 
-    byte lastLocalVariableIndex;
+    internal BitFlags256 ActiveLocalVariables = default;
 
     public byte StackStartPosition { get; private set; }
     public byte StackPosition { get; set; }
@@ -74,7 +74,7 @@ public class ScopeCompilationContext : IDisposable
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     public void PushInstruction(in Instruction instruction, SourcePosition position, bool incrementStackPosition = false)
     {
-        Function.PushOrMergeInstruction(lastLocalVariableIndex, instruction, position, ref incrementStackPosition);
+        Function.PushOrMergeInstruction(instruction, position, ref incrementStackPosition);
         if (incrementStackPosition)
         {
             StackPosition++;
@@ -98,7 +98,7 @@ public class ScopeCompilationContext : IDisposable
     public void AddLocalVariable(ReadOnlyMemory<char> name, LocalVariableDescription description, bool markAsLastLocalVariable = true)
     {
         localVariables[name] = description;
-        lastLocalVariableIndex = description.RegisterIndex;
+        ActiveLocalVariables.Set(description.RegisterIndex);
     }
 
 
@@ -165,7 +165,7 @@ public class ScopeCompilationContext : IDisposable
         HasCapturedLocalVariables = false;
         localVariables.Clear();
         labels.Clear();
-        lastLocalVariableIndex = 0;
+        ActiveLocalVariables = default;
     }
 
     /// <summary>

+ 23 - 0
src/Lua/Internal/BitFlags256.cs

@@ -0,0 +1,23 @@
+namespace Lua.Internal;
+
+internal unsafe struct BitFlags256
+{
+    internal fixed long Data[4];
+    
+    public bool this[int index]
+    {
+        get => (Data[index >> 6] & (1L << (index & 63))) != 0;
+        set
+        {
+            if (value)
+            {
+                Data[index >> 6] |= 1L << (index & 63);
+            }
+            else
+            {
+                Data[index >> 6] &= ~(1L << (index & 63));
+            }
+        }
+    }
+    public void Set(int index) => Data[index >> 6] |= 1L << (index & 63);
+}

+ 22 - 0
tests/Lua.Tests/ConditionalsTests.cs

@@ -0,0 +1,22 @@
+namespace Lua.Tests;
+
+public class ConditionalsTests
+{
+    [Test]
+    public async Task Test_Clamp()
+    {
+        var source = @"
+function clamp(x, min, max)
+    return x < min and min or (x > max and max or x)
+end
+
+return clamp(0, 1, 25), clamp(10, 1, 25), clamp(30, 1, 25)
+";
+        var result = await LuaState.Create().DoStringAsync(source);
+
+        Assert.That(result, Has.Length.EqualTo(3));
+        Assert.That(result[0], Is.EqualTo(new LuaValue(1)));
+        Assert.That(result[1], Is.EqualTo(new LuaValue(10)));
+        Assert.That(result[2], Is.EqualTo(new LuaValue(25)));
+    }
+}