Browse Source

Fix: metatable index newindex with table

Akeit0 11 months ago
parent
commit
f607866b4b
3 changed files with 223 additions and 105 deletions
  1. 19 0
      src/Lua/LuaTable.cs
  2. 164 105
      src/Lua/Runtime/LuaVirtualMachine.cs
  3. 40 0
      tests/Lua.Tests/MetatableTests.cs

+ 19 - 0
src/Lua/LuaTable.cs

@@ -109,6 +109,25 @@ public sealed class LuaTable
 
 
         return dictionary.TryGetValue(key, out value) && value.Type is not LuaValueType.Nil;
         return dictionary.TryGetValue(key, out value) && value.Type is not LuaValueType.Nil;
     }
     }
+    
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal ref LuaValue FindValue(LuaValue key)
+    {
+        if (key.Type is LuaValueType.Nil)
+        {
+            ThrowIndexIsNil();
+        }
+
+        if (TryGetInteger(key, out var index))
+        {
+            if (index > 0 && index <= array.Length)
+            {
+                return ref array[index - 1];
+            }
+        }
+
+        return ref dictionary.FindValue(key,out _);
+    }
 
 
     public bool ContainsKey(LuaValue key)
     public bool ContainsKey(LuaValue key)
     {
     {

+ 164 - 105
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -313,15 +313,11 @@ public static partial class LuaVirtualMachine
                         ref readonly var vc = ref RKC(ref stackHead, ref constHead, instruction);
                         ref readonly var vc = ref RKC(ref stackHead, ref constHead, instruction);
                         var table = context.Closure.GetUpValue(instruction.B);
                         var table = context.Closure.GetUpValue(instruction.B);
 
 
-                        if (table.TryReadTable(out var luaTable) && luaTable.TryGetValue(vc, out var resultValue))
-                        {
-                            stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
-                            continue;
-                        }
-
-                        if (TryGetMetaTableValue(table, vc, ref context, out var doRestart))
+                        var doRestart = false;
+                        if (table.TryReadTable(out var luaTable) && luaTable.TryGetValue(vc, out var resultValue) || TryGetValueWithSync(table, vc, ref context, out resultValue, out doRestart))
                         {
                         {
                             if (doRestart) goto Restart;
                             if (doRestart) goto Restart;
+                            stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
                             continue;
                             continue;
                         }
                         }
 
 
@@ -332,15 +328,11 @@ public static partial class LuaVirtualMachine
                         stackHead = ref stack.FastGet(frameBase);
                         stackHead = ref stack.FastGet(frameBase);
                         ref readonly var vb = ref Unsafe.Add(ref stackHead, instruction.UIntB);
                         ref readonly var vb = ref Unsafe.Add(ref stackHead, instruction.UIntB);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (vb.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue))
-                        {
-                            stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
-                            continue;
-                        }
-
-                        if (TryGetMetaTableValue(vb, vc, ref context, out doRestart))
+                        doRestart = false;
+                        if (vb.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue) || TryGetValueWithSync(vb, vc, ref context, out resultValue, out doRestart))
                         {
                         {
                             if (doRestart) goto Restart;
                             if (doRestart) goto Restart;
+                            stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
                             continue;
                             continue;
                         }
                         }
 
 
@@ -361,16 +353,22 @@ public static partial class LuaVirtualMachine
 
 
                         table = context.Closure.GetUpValue(instruction.A);
                         table = context.Closure.GetUpValue(instruction.A);
 
 
+
                         if (table.TryReadTable(out luaTable))
                         if (table.TryReadTable(out luaTable))
                         {
                         {
-                            luaTable[vb] = RKC(ref stackHead, ref constHead, instruction);
-                            continue;
+                            ref var valueRef = ref luaTable.FindValue(vb);
+                            if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
+                            {
+                                valueRef = RKC(ref stackHead, ref constHead, instruction);
+                                continue;
+                            }
                         }
                         }
 
 
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (TrySetMetaTableValue(table, vb, vc, ref context, out doRestart))
+                        if (TrySetMetaTableValueWithSync(table, vb, vc, ref context, out doRestart))
                         {
                         {
                             if (doRestart) goto Restart;
                             if (doRestart) goto Restart;
+
                             continue;
                             continue;
                         }
                         }
 
 
@@ -399,15 +397,16 @@ public static partial class LuaVirtualMachine
 
 
                         if (table.TryReadTable(out luaTable))
                         if (table.TryReadTable(out luaTable))
                         {
                         {
-                            if (luaTable.Metatable == null || !luaTable.Metatable.ContainsKey(Metamethods.NewIndex) || luaTable.ContainsKey(vb))
+                            ref var valueRef = ref luaTable.FindValue(vb);
+                            if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
                             {
                             {
-                                luaTable[vb] = RKC(ref stackHead, ref constHead, instruction);
+                                valueRef = RKC(ref stackHead, ref constHead, instruction);
                                 continue;
                                 continue;
                             }
                             }
                         }
                         }
 
 
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (TrySetMetaTableValue(table, vb, vc, ref context, out doRestart))
+                        if (TrySetMetaTableValueWithSync(table, vb, vc, ref context, out doRestart))
                         {
                         {
                             if (doRestart) goto Restart;
                             if (doRestart) goto Restart;
                             continue;
                             continue;
@@ -425,20 +424,30 @@ public static partial class LuaVirtualMachine
                         stackHead = ref stack.FastGet(frameBase);
                         stackHead = ref stack.FastGet(frameBase);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         table = Unsafe.Add(ref stackHead, instruction.UIntB);
                         table = Unsafe.Add(ref stackHead, instruction.UIntB);
-                        if (table.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue))
-                        {
-                            Unsafe.Add(ref stackHead, iA) = resultValue;
-                            Unsafe.Add(ref stackHead, iA + 1) = table;
-                            stack.NotifyTop(iA + frameBase + 2);
-                            continue;
-                        }
 
 
 
 
-                        if (TryGetMetaTableValue(table, vc, ref context, out doRestart))
+                        if (TryGetValueWithSync(table, vc, ref context, out resultValue, out doRestart))
                         {
                         {
                             if (doRestart) goto Restart;
                             if (doRestart) goto Restart;
+                            Unsafe.Add(ref stackHead, iA) = resultValue;
+                            Unsafe.Add(ref stackHead, iA + 1) = table;
+                            stack.NotifyTop(iA + frameBase + 2);
                             continue;
                             continue;
                         }
                         }
+                        // if (table.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue))
+                        // {
+                        //     Unsafe.Add(ref stackHead, iA) = resultValue;
+                        //     Unsafe.Add(ref stackHead, iA + 1) = table;
+                        //     stack.NotifyTop(iA + frameBase + 2);
+                        //     continue;
+                        // }
+                        //
+                        //
+                        // if (TryGetMetaTableValue(table, vc, ref context, out doRestart))
+                        // {
+                        //     if (doRestart) goto Restart;
+                        //     continue;
+                        // }
 
 
                         postOperation = PostOperationType.Self;
                         postOperation = PostOperationType.Self;
                         return true;
                         return true;
@@ -1315,126 +1324,176 @@ public static partial class LuaVirtualMachine
     }
     }
 
 
     [MethodImpl(MethodImplOptions.NoInlining)]
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static bool TryGetMetaTableValue(LuaValue table, LuaValue key, ref VirtualMachineExecutionContext context, out bool doRestart)
+    static bool TryGetValueWithSync(LuaValue table, LuaValue key, ref VirtualMachineExecutionContext context, out LuaValue value, out bool doRestart)
     {
     {
-        var isSelf = context.Instruction.OpCode == OpCode.Self;
+        var targetTable = table;
+        const int MAX_LOOP = 100;
         doRestart = false;
         doRestart = false;
-        var state = context.State;
-        if (table.TryGetMetamethod(state, Metamethods.Index, out var metamethod))
+        var skip = targetTable.Type == LuaValueType.Table;
+        for (int i = 0; i < MAX_LOOP; i++)
         {
         {
-            if (!metamethod.TryReadFunction(out var indexTable))
+            if (table.TryReadTable(out var luaTable))
             {
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
-            }
-
-            var stack = context.Stack;
-            stack.Push(table);
-            stack.Push(key);
-
+                if (!skip && luaTable.TryGetValue(key, out value))
+                {
+                    return true;
+                }
 
 
-            var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 2);
+                skip = false;
 
 
-            context.Thread.PushCallStackFrame(newFrame);
+                var metatable = luaTable.Metatable;
+                if (metatable != null && metatable.TryGetValue(Metamethods.Index, out table))
+                {
+                    goto Function;
+                }
 
 
-            if (indexTable is Closure)
-            {
-                context.Push(newFrame);
-                doRestart = true;
+                value = default;
                 return true;
                 return true;
             }
             }
 
 
-            var task = indexTable.Invoke(ref context, newFrame, 2);
-
-            if (!task.IsCompleted)
+            if (!table.TryGetMetamethod(context.State, Metamethods.Index, out var metatableValue))
             {
             {
-                context.Task = task;
-                return false;
+                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
             }
             }
 
 
-            var awaiter = task.GetAwaiter();
-            context.Thread.PopCallStackFrame();
-            var ra = context.Instruction.A + context.FrameBase;
-            var resultCount = awaiter.GetResult();
-            context.Stack.Get(ra) = resultCount == 0 ? default : context.ResultsBuffer[0];
-            if (isSelf)
-            {
-                context.Stack.Get(ra + 1) = table;
-                context.Stack.NotifyTop(ra + 2);
-            }
-            else
+            table = metatableValue;
+        Function:
+            if (table.TryReadFunction(out var function))
             {
             {
-                context.Stack.NotifyTop(ra + 1);
+                return CallGetTableFunc(targetTable, function, key, ref context, out value, out doRestart);
             }
             }
-
-            context.ClearResultsBuffer(resultCount);
-            return true;
         }
         }
 
 
-        if (table.Type == LuaValueType.Table)
-        {
-            var ra = context.Instruction.A + context.FrameBase;
-            context.Stack.Get(ra) = default;
-            if (isSelf)
-            {
-                context.Stack.Get(ra + 1) = table;
-                context.Stack.NotifyTop(ra + 2);
-            }
-            else
-            {
-                context.Stack.NotifyTop(ra + 1);
-            }
+        throw new LuaRuntimeException(GetTracebacks(ref context), "loop in gettable");
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool CallGetTableFunc(LuaValue table, LuaFunction indexTable, LuaValue key, ref VirtualMachineExecutionContext context, out LuaValue result, out bool doRestart)
+    {
+        doRestart = false;
+        var stack = context.Stack;
+        stack.Push(table);
+        stack.Push(key);
+        var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 2);
 
 
+        context.Thread.PushCallStackFrame(newFrame);
+
+        if (indexTable is Closure)
+        {
+            context.Push(newFrame);
+            doRestart = true;
+            result = default;
             return true;
             return true;
         }
         }
 
 
+        var task = indexTable.Invoke(ref context, newFrame, 2);
 
 
-        LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
-        return false;
+        if (!task.IsCompleted)
+        {
+            context.Task = task;
+            result = default;
+            return false;
+        }
+
+        var awaiter = task.GetAwaiter();
+        context.Thread.PopCallStackFrame();
+        var resultCount = awaiter.GetResult();
+        result = resultCount == 0 ? default : context.ResultsBuffer[0];
+        context.ClearResultsBuffer(resultCount);
+        return true;
     }
     }
 
 
-    static bool TrySetMetaTableValue(LuaValue table, LuaValue key, LuaValue value,
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool TrySetMetaTableValueWithSync(LuaValue table, LuaValue key, LuaValue value,
         ref VirtualMachineExecutionContext context, out bool doRestart)
         ref VirtualMachineExecutionContext context, out bool doRestart)
     {
     {
+        var targetTable = table;
+        const int MAX_LOOP = 100;
         doRestart = false;
         doRestart = false;
-        var state = context.State;
-        if (table.TryGetMetamethod(state, Metamethods.NewIndex, out var metamethod))
+        var skip = targetTable.Type == LuaValueType.Table;
+        for (int i = 0; i < MAX_LOOP; i++)
         {
         {
-            if (!metamethod.TryReadFunction(out var indexTable))
+            if (table.TryReadTable(out var luaTable))
             {
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
-            }
+                ref var valueRef = ref (skip ? ref Unsafe.NullRef<LuaValue>() : ref luaTable.FindValue(key));
+                skip = false;
+                if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
+                {
+                    luaTable[key] = value;
+                    return true;
+                }
+
+                var metatable = luaTable.Metatable;
+                if (metatable == null || !metatable.TryGetValue(Metamethods.NewIndex, out table))
+                {
+                    if (Unsafe.IsNullRef(ref valueRef))
+                    {
+                        luaTable[key] = value;
+                        return true;
+                    }
 
 
-            var thread = context.Thread;
-            var stack = thread.Stack;
-            stack.Push(table);
-            stack.Push(key);
-            stack.Push(value);
-            var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 3);
+                    valueRef = value;
+                    return true;
+                }
 
 
-            context.Thread.PushCallStackFrame(newFrame);
+                goto Function;
+            }
 
 
-            if (indexTable is Closure)
+            if (!table.TryGetMetamethod(context.State, Metamethods.NewIndex, out var metatableValue))
             {
             {
-                context.Push(newFrame);
-                doRestart = true;
-                return true;
+                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
             }
             }
 
 
-            var task = indexTable.Invoke(ref context, newFrame, 3);
-            if (!task.IsCompleted)
+            table = metatableValue;
+
+        Function:
+            if (table.TryReadFunction(out var function))
             {
             {
-                context.Task = task;
-                return false;
+                return CallSetTableFunc(targetTable, function, key, value, ref context, out doRestart);
             }
             }
+        }
+
+        throw new LuaRuntimeException(GetTracebacks(ref context), "loop in settable");
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool CallSetTableFunc(LuaValue table, LuaFunction newIndexFunction, LuaValue key, LuaValue value, ref VirtualMachineExecutionContext context, out bool doRestart)
+    {
+        doRestart = false;
+        var thread = context.Thread;
+        var stack = thread.Stack;
+        stack.Push(table);
+        stack.Push(key);
+        stack.Push(value);
+        var newFrame = newIndexFunction.CreateNewFrame(ref context, stack.Count - 3);
+
+        context.Thread.PushCallStackFrame(newFrame);
 
 
-            thread.PopCallStackFrame();
+        if (newIndexFunction is Closure)
+        {
+            context.Push(newFrame);
+            doRestart = true;
             return true;
             return true;
         }
         }
 
 
-        LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
-        return false;
+        var task = newIndexFunction.Invoke(ref context, newFrame, 3);
+        if (!task.IsCompleted)
+        {
+            context.Task = task;
+            return false;
+        }
+
+        var resultCount = task.GetAwaiter().GetResult();
+        if (0 < resultCount)
+        {
+            context.ClearResultsBuffer(resultCount);
+        }
+
+        thread.PopCallStackFrame();
+        return true;
     }
     }
 
 
+
     [MethodImpl(MethodImplOptions.NoInlining)]
     [MethodImpl(MethodImplOptions.NoInlining)]
     static bool ExecuteBinaryOperationMetaMethod(LuaValue vb, LuaValue vc,
     static bool ExecuteBinaryOperationMetaMethod(LuaValue vb, LuaValue vc,
         ref VirtualMachineExecutionContext context, string name, string description, out bool doRestart)
         ref VirtualMachineExecutionContext context, string name, string description, out bool doRestart)

+ 40 - 0
tests/Lua.Tests/MetatableTests.cs

@@ -46,4 +46,44 @@ return a + b
             Assert.That(table[3].Read<double>(), Is.EqualTo(9));
             Assert.That(table[3].Read<double>(), Is.EqualTo(9));
         });
         });
     }
     }
+
+    [Test]
+    public async Task Test_Metamethod_Index()
+    {
+        var source = @"
+metatable = {
+    __index = {x=1}
+}
+
+local a = {}
+setmetatable(a, metatable)
+assert(a.x == 1)
+metatable.__index= nil
+assert(a.x == nil)
+metatable.__index= function(a,b) return b end
+assert(a.x == 'x')
+";
+        await state.DoStringAsync(source);
+    }
+
+    [Test]
+    public async Task Test_Metamethod_NewIndex()
+    {
+        var source = @"
+metatable = {
+    __newindex = {}
+}
+
+local a = {}
+a.x = 1
+setmetatable(a, metatable)
+a.x = 2
+assert(a.x == 2)
+a.x = nil
+a.x = 2
+assert(a.x == nil)
+assert(metatable.__newindex.x == 2)
+";
+        await state.DoStringAsync(source);
+    }
 }
 }