Browse Source

Fix: metatable index newindex with table

Akeit0 11 months ago
parent
commit
f607866b4b
3 changed files with 223 additions and 105 deletions
  1. 19 0
      src/Lua/LuaTable.cs
  2. 164 105
      src/Lua/Runtime/LuaVirtualMachine.cs
  3. 40 0
      tests/Lua.Tests/MetatableTests.cs

+ 19 - 0
src/Lua/LuaTable.cs

@@ -109,6 +109,25 @@ public sealed class LuaTable
 
         return dictionary.TryGetValue(key, out value) && value.Type is not LuaValueType.Nil;
     }
+    
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal ref LuaValue FindValue(LuaValue key)
+    {
+        if (key.Type is LuaValueType.Nil)
+        {
+            ThrowIndexIsNil();
+        }
+
+        if (TryGetInteger(key, out var index))
+        {
+            if (index > 0 && index <= array.Length)
+            {
+                return ref array[index - 1];
+            }
+        }
+
+        return ref dictionary.FindValue(key,out _);
+    }
 
     public bool ContainsKey(LuaValue key)
     {

+ 164 - 105
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -313,15 +313,11 @@ public static partial class LuaVirtualMachine
                         ref readonly var vc = ref RKC(ref stackHead, ref constHead, instruction);
                         var table = context.Closure.GetUpValue(instruction.B);
 
-                        if (table.TryReadTable(out var luaTable) && luaTable.TryGetValue(vc, out var resultValue))
-                        {
-                            stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
-                            continue;
-                        }
-
-                        if (TryGetMetaTableValue(table, vc, ref context, out var doRestart))
+                        var doRestart = false;
+                        if (table.TryReadTable(out var luaTable) && luaTable.TryGetValue(vc, out var resultValue) || TryGetValueWithSync(table, vc, ref context, out resultValue, out doRestart))
                         {
                             if (doRestart) goto Restart;
+                            stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
                             continue;
                         }
 
@@ -332,15 +328,11 @@ public static partial class LuaVirtualMachine
                         stackHead = ref stack.FastGet(frameBase);
                         ref readonly var vb = ref Unsafe.Add(ref stackHead, instruction.UIntB);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (vb.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue))
-                        {
-                            stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
-                            continue;
-                        }
-
-                        if (TryGetMetaTableValue(vb, vc, ref context, out doRestart))
+                        doRestart = false;
+                        if (vb.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue) || TryGetValueWithSync(vb, vc, ref context, out resultValue, out doRestart))
                         {
                             if (doRestart) goto Restart;
+                            stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
                             continue;
                         }
 
@@ -361,16 +353,22 @@ public static partial class LuaVirtualMachine
 
                         table = context.Closure.GetUpValue(instruction.A);
 
+
                         if (table.TryReadTable(out luaTable))
                         {
-                            luaTable[vb] = RKC(ref stackHead, ref constHead, instruction);
-                            continue;
+                            ref var valueRef = ref luaTable.FindValue(vb);
+                            if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
+                            {
+                                valueRef = RKC(ref stackHead, ref constHead, instruction);
+                                continue;
+                            }
                         }
 
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (TrySetMetaTableValue(table, vb, vc, ref context, out doRestart))
+                        if (TrySetMetaTableValueWithSync(table, vb, vc, ref context, out doRestart))
                         {
                             if (doRestart) goto Restart;
+
                             continue;
                         }
 
@@ -399,15 +397,16 @@ public static partial class LuaVirtualMachine
 
                         if (table.TryReadTable(out luaTable))
                         {
-                            if (luaTable.Metatable == null || !luaTable.Metatable.ContainsKey(Metamethods.NewIndex) || luaTable.ContainsKey(vb))
+                            ref var valueRef = ref luaTable.FindValue(vb);
+                            if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
                             {
-                                luaTable[vb] = RKC(ref stackHead, ref constHead, instruction);
+                                valueRef = RKC(ref stackHead, ref constHead, instruction);
                                 continue;
                             }
                         }
 
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (TrySetMetaTableValue(table, vb, vc, ref context, out doRestart))
+                        if (TrySetMetaTableValueWithSync(table, vb, vc, ref context, out doRestart))
                         {
                             if (doRestart) goto Restart;
                             continue;
@@ -425,20 +424,30 @@ public static partial class LuaVirtualMachine
                         stackHead = ref stack.FastGet(frameBase);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         table = Unsafe.Add(ref stackHead, instruction.UIntB);
-                        if (table.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue))
-                        {
-                            Unsafe.Add(ref stackHead, iA) = resultValue;
-                            Unsafe.Add(ref stackHead, iA + 1) = table;
-                            stack.NotifyTop(iA + frameBase + 2);
-                            continue;
-                        }
 
 
-                        if (TryGetMetaTableValue(table, vc, ref context, out doRestart))
+                        if (TryGetValueWithSync(table, vc, ref context, out resultValue, out doRestart))
                         {
                             if (doRestart) goto Restart;
+                            Unsafe.Add(ref stackHead, iA) = resultValue;
+                            Unsafe.Add(ref stackHead, iA + 1) = table;
+                            stack.NotifyTop(iA + frameBase + 2);
                             continue;
                         }
+                        // if (table.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue))
+                        // {
+                        //     Unsafe.Add(ref stackHead, iA) = resultValue;
+                        //     Unsafe.Add(ref stackHead, iA + 1) = table;
+                        //     stack.NotifyTop(iA + frameBase + 2);
+                        //     continue;
+                        // }
+                        //
+                        //
+                        // if (TryGetMetaTableValue(table, vc, ref context, out doRestart))
+                        // {
+                        //     if (doRestart) goto Restart;
+                        //     continue;
+                        // }
 
                         postOperation = PostOperationType.Self;
                         return true;
@@ -1315,126 +1324,176 @@ public static partial class LuaVirtualMachine
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static bool TryGetMetaTableValue(LuaValue table, LuaValue key, ref VirtualMachineExecutionContext context, out bool doRestart)
+    static bool TryGetValueWithSync(LuaValue table, LuaValue key, ref VirtualMachineExecutionContext context, out LuaValue value, out bool doRestart)
     {
-        var isSelf = context.Instruction.OpCode == OpCode.Self;
+        var targetTable = table;
+        const int MAX_LOOP = 100;
         doRestart = false;
-        var state = context.State;
-        if (table.TryGetMetamethod(state, Metamethods.Index, out var metamethod))
+        var skip = targetTable.Type == LuaValueType.Table;
+        for (int i = 0; i < MAX_LOOP; i++)
         {
-            if (!metamethod.TryReadFunction(out var indexTable))
+            if (table.TryReadTable(out var luaTable))
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
-            }
-
-            var stack = context.Stack;
-            stack.Push(table);
-            stack.Push(key);
-
+                if (!skip && luaTable.TryGetValue(key, out value))
+                {
+                    return true;
+                }
 
-            var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 2);
+                skip = false;
 
-            context.Thread.PushCallStackFrame(newFrame);
+                var metatable = luaTable.Metatable;
+                if (metatable != null && metatable.TryGetValue(Metamethods.Index, out table))
+                {
+                    goto Function;
+                }
 
-            if (indexTable is Closure)
-            {
-                context.Push(newFrame);
-                doRestart = true;
+                value = default;
                 return true;
             }
 
-            var task = indexTable.Invoke(ref context, newFrame, 2);
-
-            if (!task.IsCompleted)
+            if (!table.TryGetMetamethod(context.State, Metamethods.Index, out var metatableValue))
             {
-                context.Task = task;
-                return false;
+                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
             }
 
-            var awaiter = task.GetAwaiter();
-            context.Thread.PopCallStackFrame();
-            var ra = context.Instruction.A + context.FrameBase;
-            var resultCount = awaiter.GetResult();
-            context.Stack.Get(ra) = resultCount == 0 ? default : context.ResultsBuffer[0];
-            if (isSelf)
-            {
-                context.Stack.Get(ra + 1) = table;
-                context.Stack.NotifyTop(ra + 2);
-            }
-            else
+            table = metatableValue;
+        Function:
+            if (table.TryReadFunction(out var function))
             {
-                context.Stack.NotifyTop(ra + 1);
+                return CallGetTableFunc(targetTable, function, key, ref context, out value, out doRestart);
             }
-
-            context.ClearResultsBuffer(resultCount);
-            return true;
         }
 
-        if (table.Type == LuaValueType.Table)
-        {
-            var ra = context.Instruction.A + context.FrameBase;
-            context.Stack.Get(ra) = default;
-            if (isSelf)
-            {
-                context.Stack.Get(ra + 1) = table;
-                context.Stack.NotifyTop(ra + 2);
-            }
-            else
-            {
-                context.Stack.NotifyTop(ra + 1);
-            }
+        throw new LuaRuntimeException(GetTracebacks(ref context), "loop in gettable");
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool CallGetTableFunc(LuaValue table, LuaFunction indexTable, LuaValue key, ref VirtualMachineExecutionContext context, out LuaValue result, out bool doRestart)
+    {
+        doRestart = false;
+        var stack = context.Stack;
+        stack.Push(table);
+        stack.Push(key);
+        var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 2);
 
+        context.Thread.PushCallStackFrame(newFrame);
+
+        if (indexTable is Closure)
+        {
+            context.Push(newFrame);
+            doRestart = true;
+            result = default;
             return true;
         }
 
+        var task = indexTable.Invoke(ref context, newFrame, 2);
 
-        LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
-        return false;
+        if (!task.IsCompleted)
+        {
+            context.Task = task;
+            result = default;
+            return false;
+        }
+
+        var awaiter = task.GetAwaiter();
+        context.Thread.PopCallStackFrame();
+        var resultCount = awaiter.GetResult();
+        result = resultCount == 0 ? default : context.ResultsBuffer[0];
+        context.ClearResultsBuffer(resultCount);
+        return true;
     }
 
-    static bool TrySetMetaTableValue(LuaValue table, LuaValue key, LuaValue value,
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool TrySetMetaTableValueWithSync(LuaValue table, LuaValue key, LuaValue value,
         ref VirtualMachineExecutionContext context, out bool doRestart)
     {
+        var targetTable = table;
+        const int MAX_LOOP = 100;
         doRestart = false;
-        var state = context.State;
-        if (table.TryGetMetamethod(state, Metamethods.NewIndex, out var metamethod))
+        var skip = targetTable.Type == LuaValueType.Table;
+        for (int i = 0; i < MAX_LOOP; i++)
         {
-            if (!metamethod.TryReadFunction(out var indexTable))
+            if (table.TryReadTable(out var luaTable))
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
-            }
+                ref var valueRef = ref (skip ? ref Unsafe.NullRef<LuaValue>() : ref luaTable.FindValue(key));
+                skip = false;
+                if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
+                {
+                    luaTable[key] = value;
+                    return true;
+                }
+
+                var metatable = luaTable.Metatable;
+                if (metatable == null || !metatable.TryGetValue(Metamethods.NewIndex, out table))
+                {
+                    if (Unsafe.IsNullRef(ref valueRef))
+                    {
+                        luaTable[key] = value;
+                        return true;
+                    }
 
-            var thread = context.Thread;
-            var stack = thread.Stack;
-            stack.Push(table);
-            stack.Push(key);
-            stack.Push(value);
-            var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 3);
+                    valueRef = value;
+                    return true;
+                }
 
-            context.Thread.PushCallStackFrame(newFrame);
+                goto Function;
+            }
 
-            if (indexTable is Closure)
+            if (!table.TryGetMetamethod(context.State, Metamethods.NewIndex, out var metatableValue))
             {
-                context.Push(newFrame);
-                doRestart = true;
-                return true;
+                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
             }
 
-            var task = indexTable.Invoke(ref context, newFrame, 3);
-            if (!task.IsCompleted)
+            table = metatableValue;
+
+        Function:
+            if (table.TryReadFunction(out var function))
             {
-                context.Task = task;
-                return false;
+                return CallSetTableFunc(targetTable, function, key, value, ref context, out doRestart);
             }
+        }
+
+        throw new LuaRuntimeException(GetTracebacks(ref context), "loop in settable");
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool CallSetTableFunc(LuaValue table, LuaFunction newIndexFunction, LuaValue key, LuaValue value, ref VirtualMachineExecutionContext context, out bool doRestart)
+    {
+        doRestart = false;
+        var thread = context.Thread;
+        var stack = thread.Stack;
+        stack.Push(table);
+        stack.Push(key);
+        stack.Push(value);
+        var newFrame = newIndexFunction.CreateNewFrame(ref context, stack.Count - 3);
+
+        context.Thread.PushCallStackFrame(newFrame);
 
-            thread.PopCallStackFrame();
+        if (newIndexFunction is Closure)
+        {
+            context.Push(newFrame);
+            doRestart = true;
             return true;
         }
 
-        LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
-        return false;
+        var task = newIndexFunction.Invoke(ref context, newFrame, 3);
+        if (!task.IsCompleted)
+        {
+            context.Task = task;
+            return false;
+        }
+
+        var resultCount = task.GetAwaiter().GetResult();
+        if (0 < resultCount)
+        {
+            context.ClearResultsBuffer(resultCount);
+        }
+
+        thread.PopCallStackFrame();
+        return true;
     }
 
+
     [MethodImpl(MethodImplOptions.NoInlining)]
     static bool ExecuteBinaryOperationMetaMethod(LuaValue vb, LuaValue vc,
         ref VirtualMachineExecutionContext context, string name, string description, out bool doRestart)

+ 40 - 0
tests/Lua.Tests/MetatableTests.cs

@@ -46,4 +46,44 @@ return a + b
             Assert.That(table[3].Read<double>(), Is.EqualTo(9));
         });
     }
+
+    [Test]
+    public async Task Test_Metamethod_Index()
+    {
+        var source = @"
+metatable = {
+    __index = {x=1}
+}
+
+local a = {}
+setmetatable(a, metatable)
+assert(a.x == 1)
+metatable.__index= nil
+assert(a.x == nil)
+metatable.__index= function(a,b) return b end
+assert(a.x == 'x')
+";
+        await state.DoStringAsync(source);
+    }
+
+    [Test]
+    public async Task Test_Metamethod_NewIndex()
+    {
+        var source = @"
+metatable = {
+    __newindex = {}
+}
+
+local a = {}
+a.x = 1
+setmetatable(a, metatable)
+a.x = 2
+assert(a.x == 2)
+a.x = nil
+a.x = 2
+assert(a.x == nil)
+assert(metatable.__newindex.x == 2)
+";
+        await state.DoStringAsync(source);
+    }
 }