Browse Source

Merge pull request #72 from Akeit0/fix-metatable-index

Fix: metatable index/newindex with table
Annulus Games 11 months ago
parent
commit
a26a4abbc4

+ 20 - 1
src/Lua/LuaTable.cs

@@ -110,6 +110,25 @@ public sealed class LuaTable
         return dictionary.TryGetValue(key, out value) && value.Type is not LuaValueType.Nil;
     }
 
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal ref LuaValue FindValue(LuaValue key)
+    {
+        if (key.Type is LuaValueType.Nil)
+        {
+            ThrowIndexIsNil();
+        }
+
+        if (TryGetInteger(key, out var index))
+        {
+            if (index > 0 && index <= array.Length)
+            {
+                return ref array[index - 1];
+            }
+        }
+
+        return ref dictionary.FindValue(key, out _);
+    }
+
     public bool ContainsKey(LuaValue key)
     {
         if (key.Type is LuaValueType.Nil)
@@ -194,7 +213,7 @@ public sealed class LuaTable
         }
         else
         {
-            if(dictionary.TryGetNext(key, out pair))
+            if (dictionary.TryGetNext(key, out pair))
             {
                 return true;
             }

+ 10 - 4
src/Lua/Runtime/Closure.cs

@@ -31,6 +31,12 @@ public sealed class Closure : LuaFunction
         return upValues[index].GetValue();
     }
 
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal ref readonly LuaValue GetUpValueRef(int index)
+    {
+        return ref upValues[index].GetValueRef();
+    }
+
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     internal void SetUpValue(int index, LuaValue value)
     {
@@ -43,17 +49,17 @@ public sealed class Closure : LuaFunction
         {
             return state.GetOrAddUpValue(thread, thread.GetCallStackFrames()[^1].Base + description.Index);
         }
-        
+
         if (description.Index == -1) // -1 is global environment
         {
             return envUpValue;
         }
-        
+
         if (thread.GetCallStackFrames()[^1].Function is Closure parentClosure)
         {
-             return parentClosure.UpValues[description.Index];
+            return parentClosure.UpValues[description.Index];
         }
-        
+
         throw new Exception();
     }
 }

+ 155 - 128
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -92,20 +92,19 @@ public static partial class LuaVirtualMachine
             switch (opCode)
             {
                 case OpCode.Call:
-                {
-                    var c = callInstruction.C;
-                    if (c != 0)
                     {
-                        targetCount = c - 1;
-                    }
+                        var c = callInstruction.C;
+                        if (c != 0)
+                        {
+                            targetCount = c - 1;
+                        }
 
-                    break;
-                }
+                        break;
+                    }
                 case OpCode.TForCall:
                     target += 3;
                     targetCount = callInstruction.C;
                     break;
-
                 case OpCode.Self:
                     Stack.Get(target) = result.Length == 0 ? LuaValue.Nil : result[0];
                     Thread.PopCallStackFrameUnsafe(target + 2);
@@ -266,7 +265,7 @@ public static partial class LuaVirtualMachine
 
         try
         {
-            // This is a label to restart the execution when new function is called or restarted
+        // This is a label to restart the execution when new function is called or restarted
         Restart:
             ref var instructionsHead = ref context.Chunk.Instructions[0];
             var frameBase = context.FrameBase;
@@ -308,42 +307,19 @@ public static partial class LuaVirtualMachine
                         stack.GetWithNotifyTop(instruction.A + frameBase) = context.Closure.GetUpValue(instruction.B);
                         continue;
                     case OpCode.GetTabUp:
+                    case OpCode.GetTable:
                         instruction = instructionRef;
                         stackHead = ref stack.FastGet(frameBase);
                         ref readonly var vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        var table = context.Closure.GetUpValue(instruction.B);
-
-                        if (table.TryReadTable(out var luaTable) && luaTable.TryGetValue(vc, out var resultValue))
-                        {
-                            stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
-                            continue;
-                        }
-
-                        if (TryGetMetaTableValue(table, vc, ref context, out var doRestart))
+                        ref readonly var vb = ref (instruction.OpCode == OpCode.GetTable ? ref Unsafe.Add(ref stackHead, instruction.UIntB) : ref context.Closure.GetUpValueRef(instruction.B));
+                        var doRestart = false;
+                        if (vb.TryReadTable(out var luaTable) && luaTable.TryGetValue(vc, out var resultValue) || GetTableValueSlowPath(vb, vc, ref context, out resultValue, out doRestart))
                         {
                             if (doRestart) goto Restart;
-                            continue;
-                        }
-
-                        postOperation = PostOperationType.SetResult;
-                        return true;
-                    case OpCode.GetTable:
-                        instruction = instructionRef;
-                        stackHead = ref stack.FastGet(frameBase);
-                        ref readonly var vb = ref Unsafe.Add(ref stackHead, instruction.UIntB);
-                        vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (vb.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue))
-                        {
                             stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
                             continue;
                         }
 
-                        if (TryGetMetaTableValue(vb, vc, ref context, out doRestart))
-                        {
-                            if (doRestart) goto Restart;
-                            continue;
-                        }
-
                         postOperation = PostOperationType.SetResult;
                         return true;
                     case OpCode.SetTabUp:
@@ -359,16 +335,20 @@ public static partial class LuaVirtualMachine
                             }
                         }
 
-                        table = context.Closure.GetUpValue(instruction.A);
+                        var table = context.Closure.GetUpValue(instruction.A);
 
                         if (table.TryReadTable(out luaTable))
                         {
-                            luaTable[vb] = RKC(ref stackHead, ref constHead, instruction);
-                            continue;
+                            ref var valueRef = ref luaTable.FindValue(vb);
+                            if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
+                            {
+                                valueRef = RKC(ref stackHead, ref constHead, instruction);
+                                continue;
+                            }
                         }
 
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (TrySetMetaTableValue(table, vb, vc, ref context, out doRestart))
+                        if (SetTableValueSlowPath(table, vb, vc, ref context, out doRestart))
                         {
                             if (doRestart) goto Restart;
                             continue;
@@ -399,15 +379,16 @@ public static partial class LuaVirtualMachine
 
                         if (table.TryReadTable(out luaTable))
                         {
-                            if (luaTable.Metatable == null || !luaTable.Metatable.ContainsKey(Metamethods.NewIndex) || luaTable.ContainsKey(vb))
+                            ref var valueRef = ref luaTable.FindValue(vb);
+                            if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
                             {
-                                luaTable[vb] = RKC(ref stackHead, ref constHead, instruction);
+                                valueRef = RKC(ref stackHead, ref constHead, instruction);
                                 continue;
                             }
                         }
 
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (TrySetMetaTableValue(table, vb, vc, ref context, out doRestart))
+                        if (SetTableValueSlowPath(table, vb, vc, ref context, out doRestart))
                         {
                             if (doRestart) goto Restart;
                             continue;
@@ -425,21 +406,17 @@ public static partial class LuaVirtualMachine
                         stackHead = ref stack.FastGet(frameBase);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         table = Unsafe.Add(ref stackHead, instruction.UIntB);
-                        if (table.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue))
+
+                        doRestart = false;
+                        if ((table.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue)) || GetTableValueSlowPath(table, vc, ref context, out resultValue, out doRestart))
                         {
+                            if (doRestart) goto Restart;
                             Unsafe.Add(ref stackHead, iA) = resultValue;
                             Unsafe.Add(ref stackHead, iA + 1) = table;
                             stack.NotifyTop(iA + frameBase + 2);
                             continue;
                         }
 
-
-                        if (TryGetMetaTableValue(table, vc, ref context, out doRestart))
-                        {
-                            if (doRestart) goto Restart;
-                            continue;
-                        }
-
                         postOperation = PostOperationType.Self;
                         return true;
                     case OpCode.Add:
@@ -1315,126 +1292,176 @@ public static partial class LuaVirtualMachine
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static bool TryGetMetaTableValue(LuaValue table, LuaValue key, ref VirtualMachineExecutionContext context, out bool doRestart)
+    static bool GetTableValueSlowPath(LuaValue table, LuaValue key, ref VirtualMachineExecutionContext context, out LuaValue value, out bool doRestart)
     {
-        var isSelf = context.Instruction.OpCode == OpCode.Self;
+        var targetTable = table;
+        const int MAX_LOOP = 100;
         doRestart = false;
-        var state = context.State;
-        if (table.TryGetMetamethod(state, Metamethods.Index, out var metamethod))
+        var skip = targetTable.Type == LuaValueType.Table;
+        for (int i = 0; i < MAX_LOOP; i++)
         {
-            if (!metamethod.TryReadFunction(out var indexTable))
+            if (table.TryReadTable(out var luaTable))
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
-            }
-
-            var stack = context.Stack;
-            stack.Push(table);
-            stack.Push(key);
-
+                if (!skip && luaTable.TryGetValue(key, out value))
+                {
+                    return true;
+                }
 
-            var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 2);
+                skip = false;
 
-            context.Thread.PushCallStackFrame(newFrame);
+                var metatable = luaTable.Metatable;
+                if (metatable != null && metatable.TryGetValue(Metamethods.Index, out table))
+                {
+                    goto Function;
+                }
 
-            if (indexTable is Closure)
-            {
-                context.Push(newFrame);
-                doRestart = true;
+                value = default;
                 return true;
             }
 
-            var task = indexTable.Invoke(ref context, newFrame, 2);
-
-            if (!task.IsCompleted)
+            if (!table.TryGetMetamethod(context.State, Metamethods.Index, out var metatableValue))
             {
-                context.Task = task;
-                return false;
+                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
             }
 
-            var awaiter = task.GetAwaiter();
-            context.Thread.PopCallStackFrame();
-            var ra = context.Instruction.A + context.FrameBase;
-            var resultCount = awaiter.GetResult();
-            context.Stack.Get(ra) = resultCount == 0 ? default : context.ResultsBuffer[0];
-            if (isSelf)
-            {
-                context.Stack.Get(ra + 1) = table;
-                context.Stack.NotifyTop(ra + 2);
-            }
-            else
+            table = metatableValue;
+        Function:
+            if (table.TryReadFunction(out var function))
             {
-                context.Stack.NotifyTop(ra + 1);
+                return CallGetTableFunc(targetTable, function, key, ref context, out value, out doRestart);
             }
-
-            context.ClearResultsBuffer(resultCount);
-            return true;
         }
 
-        if (table.Type == LuaValueType.Table)
-        {
-            var ra = context.Instruction.A + context.FrameBase;
-            context.Stack.Get(ra) = default;
-            if (isSelf)
-            {
-                context.Stack.Get(ra + 1) = table;
-                context.Stack.NotifyTop(ra + 2);
-            }
-            else
-            {
-                context.Stack.NotifyTop(ra + 1);
-            }
+        throw new LuaRuntimeException(GetTracebacks(ref context), "loop in gettable");
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool CallGetTableFunc(LuaValue table, LuaFunction indexTable, LuaValue key, ref VirtualMachineExecutionContext context, out LuaValue result, out bool doRestart)
+    {
+        doRestart = false;
+        var stack = context.Stack;
+        stack.Push(table);
+        stack.Push(key);
+        var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 2);
 
+        context.Thread.PushCallStackFrame(newFrame);
+
+        if (indexTable is Closure)
+        {
+            context.Push(newFrame);
+            doRestart = true;
+            result = default;
             return true;
         }
 
+        var task = indexTable.Invoke(ref context, newFrame, 2);
 
-        LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
-        return false;
+        if (!task.IsCompleted)
+        {
+            context.Task = task;
+            result = default;
+            return false;
+        }
+
+        var awaiter = task.GetAwaiter();
+        context.Thread.PopCallStackFrame();
+        var resultCount = awaiter.GetResult();
+        result = resultCount == 0 ? default : context.ResultsBuffer[0];
+        context.ClearResultsBuffer(resultCount);
+        return true;
     }
 
-    static bool TrySetMetaTableValue(LuaValue table, LuaValue key, LuaValue value,
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool SetTableValueSlowPath(LuaValue table, LuaValue key, LuaValue value,
         ref VirtualMachineExecutionContext context, out bool doRestart)
     {
+        var targetTable = table;
+        const int MAX_LOOP = 100;
         doRestart = false;
-        var state = context.State;
-        if (table.TryGetMetamethod(state, Metamethods.NewIndex, out var metamethod))
+        var skip = targetTable.Type == LuaValueType.Table;
+        for (int i = 0; i < MAX_LOOP; i++)
         {
-            if (!metamethod.TryReadFunction(out var indexTable))
+            if (table.TryReadTable(out var luaTable))
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
-            }
+                ref var valueRef = ref (skip ? ref Unsafe.NullRef<LuaValue>() : ref luaTable.FindValue(key));
+                skip = false;
+                if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
+                {
+                    valueRef = value;
+                    return true;
+                }
 
-            var thread = context.Thread;
-            var stack = thread.Stack;
-            stack.Push(table);
-            stack.Push(key);
-            stack.Push(value);
-            var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 3);
+                var metatable = luaTable.Metatable;
+                if (metatable == null || !metatable.TryGetValue(Metamethods.NewIndex, out table))
+                {
+                    if (Unsafe.IsNullRef(ref valueRef))
+                    {
+                        luaTable[key] = value;
+                        return true;
+                    }
 
-            context.Thread.PushCallStackFrame(newFrame);
+                    valueRef = value;
+                    return true;
+                }
+
+                goto Function;
+            }
 
-            if (indexTable is Closure)
+            if (!table.TryGetMetamethod(context.State, Metamethods.NewIndex, out var metatableValue))
             {
-                context.Push(newFrame);
-                doRestart = true;
-                return true;
+                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
             }
 
-            var task = indexTable.Invoke(ref context, newFrame, 3);
-            if (!task.IsCompleted)
+            table = metatableValue;
+
+        Function:
+            if (table.TryReadFunction(out var function))
             {
-                context.Task = task;
-                return false;
+                return CallSetTableFunc(targetTable, function, key, value, ref context, out doRestart);
             }
+        }
+
+        throw new LuaRuntimeException(GetTracebacks(ref context), "loop in settable");
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool CallSetTableFunc(LuaValue table, LuaFunction newIndexFunction, LuaValue key, LuaValue value, ref VirtualMachineExecutionContext context, out bool doRestart)
+    {
+        doRestart = false;
+        var thread = context.Thread;
+        var stack = thread.Stack;
+        stack.Push(table);
+        stack.Push(key);
+        stack.Push(value);
+        var newFrame = newIndexFunction.CreateNewFrame(ref context, stack.Count - 3);
 
-            thread.PopCallStackFrame();
+        context.Thread.PushCallStackFrame(newFrame);
+
+        if (newIndexFunction is Closure)
+        {
+            context.Push(newFrame);
+            doRestart = true;
             return true;
         }
 
-        LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
-        return false;
+        var task = newIndexFunction.Invoke(ref context, newFrame, 3);
+        if (!task.IsCompleted)
+        {
+            context.Task = task;
+            return false;
+        }
+
+        var resultCount = task.GetAwaiter().GetResult();
+        if (0 < resultCount)
+        {
+            context.ClearResultsBuffer(resultCount);
+        }
+
+        thread.PopCallStackFrame();
+        return true;
     }
 
+
     [MethodImpl(MethodImplOptions.NoInlining)]
     static bool ExecuteBinaryOperationMetaMethod(LuaValue vb, LuaValue vc,
         ref VirtualMachineExecutionContext context, string name, string description, out bool doRestart)

+ 13 - 0
src/Lua/Runtime/UpValue.cs

@@ -45,6 +45,19 @@ public sealed class UpValue
             return Thread!.Stack.Get(RegisterIndex);
         }
     }
+    
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal ref readonly LuaValue GetValueRef()
+    {
+        if (IsClosed)
+        {
+            return ref value;
+        }
+        else
+        {
+            return ref Thread!.Stack.Get(RegisterIndex);
+        }
+    }
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     public void SetValue(LuaValue value)

+ 40 - 0
tests/Lua.Tests/MetatableTests.cs

@@ -46,4 +46,44 @@ return a + b
             Assert.That(table[3].Read<double>(), Is.EqualTo(9));
         });
     }
+
+    [Test]
+    public async Task Test_Metamethod_Index()
+    {
+        var source = @"
+metatable = {
+    __index = {x=1}
+}
+
+local a = {}
+setmetatable(a, metatable)
+assert(a.x == 1)
+metatable.__index= nil
+assert(a.x == nil)
+metatable.__index= function(a,b) return b end
+assert(a.x == 'x')
+";
+        await state.DoStringAsync(source);
+    }
+
+    [Test]
+    public async Task Test_Metamethod_NewIndex()
+    {
+        var source = @"
+metatable = {
+    __newindex = {}
+}
+
+local a = {}
+a.x = 1
+setmetatable(a, metatable)
+a.x = 2
+assert(a.x == 2)
+a.x = nil
+a.x = 2
+assert(a.x == nil)
+assert(metatable.__newindex.x == 2)
+";
+        await state.DoStringAsync(source);
+    }
 }