Browse Source

Merge pull request #72 from Akeit0/fix-metatable-index

Fix: metatable index/newindex with table
Annulus Games 11 months ago
parent
commit
a26a4abbc4

+ 20 - 1
src/Lua/LuaTable.cs

@@ -110,6 +110,25 @@ public sealed class LuaTable
         return dictionary.TryGetValue(key, out value) && value.Type is not LuaValueType.Nil;
         return dictionary.TryGetValue(key, out value) && value.Type is not LuaValueType.Nil;
     }
     }
 
 
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal ref LuaValue FindValue(LuaValue key)
+    {
+        if (key.Type is LuaValueType.Nil)
+        {
+            ThrowIndexIsNil();
+        }
+
+        if (TryGetInteger(key, out var index))
+        {
+            if (index > 0 && index <= array.Length)
+            {
+                return ref array[index - 1];
+            }
+        }
+
+        return ref dictionary.FindValue(key, out _);
+    }
+
     public bool ContainsKey(LuaValue key)
     public bool ContainsKey(LuaValue key)
     {
     {
         if (key.Type is LuaValueType.Nil)
         if (key.Type is LuaValueType.Nil)
@@ -194,7 +213,7 @@ public sealed class LuaTable
         }
         }
         else
         else
         {
         {
-            if(dictionary.TryGetNext(key, out pair))
+            if (dictionary.TryGetNext(key, out pair))
             {
             {
                 return true;
                 return true;
             }
             }

+ 10 - 4
src/Lua/Runtime/Closure.cs

@@ -31,6 +31,12 @@ public sealed class Closure : LuaFunction
         return upValues[index].GetValue();
         return upValues[index].GetValue();
     }
     }
 
 
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal ref readonly LuaValue GetUpValueRef(int index)
+    {
+        return ref upValues[index].GetValueRef();
+    }
+
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     internal void SetUpValue(int index, LuaValue value)
     internal void SetUpValue(int index, LuaValue value)
     {
     {
@@ -43,17 +49,17 @@ public sealed class Closure : LuaFunction
         {
         {
             return state.GetOrAddUpValue(thread, thread.GetCallStackFrames()[^1].Base + description.Index);
             return state.GetOrAddUpValue(thread, thread.GetCallStackFrames()[^1].Base + description.Index);
         }
         }
-        
+
         if (description.Index == -1) // -1 is global environment
         if (description.Index == -1) // -1 is global environment
         {
         {
             return envUpValue;
             return envUpValue;
         }
         }
-        
+
         if (thread.GetCallStackFrames()[^1].Function is Closure parentClosure)
         if (thread.GetCallStackFrames()[^1].Function is Closure parentClosure)
         {
         {
-             return parentClosure.UpValues[description.Index];
+            return parentClosure.UpValues[description.Index];
         }
         }
-        
+
         throw new Exception();
         throw new Exception();
     }
     }
 }
 }

+ 155 - 128
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -92,20 +92,19 @@ public static partial class LuaVirtualMachine
             switch (opCode)
             switch (opCode)
             {
             {
                 case OpCode.Call:
                 case OpCode.Call:
-                {
-                    var c = callInstruction.C;
-                    if (c != 0)
                     {
                     {
-                        targetCount = c - 1;
-                    }
+                        var c = callInstruction.C;
+                        if (c != 0)
+                        {
+                            targetCount = c - 1;
+                        }
 
 
-                    break;
-                }
+                        break;
+                    }
                 case OpCode.TForCall:
                 case OpCode.TForCall:
                     target += 3;
                     target += 3;
                     targetCount = callInstruction.C;
                     targetCount = callInstruction.C;
                     break;
                     break;
-
                 case OpCode.Self:
                 case OpCode.Self:
                     Stack.Get(target) = result.Length == 0 ? LuaValue.Nil : result[0];
                     Stack.Get(target) = result.Length == 0 ? LuaValue.Nil : result[0];
                     Thread.PopCallStackFrameUnsafe(target + 2);
                     Thread.PopCallStackFrameUnsafe(target + 2);
@@ -266,7 +265,7 @@ public static partial class LuaVirtualMachine
 
 
         try
         try
         {
         {
-            // This is a label to restart the execution when new function is called or restarted
+        // This is a label to restart the execution when new function is called or restarted
         Restart:
         Restart:
             ref var instructionsHead = ref context.Chunk.Instructions[0];
             ref var instructionsHead = ref context.Chunk.Instructions[0];
             var frameBase = context.FrameBase;
             var frameBase = context.FrameBase;
@@ -308,42 +307,19 @@ public static partial class LuaVirtualMachine
                         stack.GetWithNotifyTop(instruction.A + frameBase) = context.Closure.GetUpValue(instruction.B);
                         stack.GetWithNotifyTop(instruction.A + frameBase) = context.Closure.GetUpValue(instruction.B);
                         continue;
                         continue;
                     case OpCode.GetTabUp:
                     case OpCode.GetTabUp:
+                    case OpCode.GetTable:
                         instruction = instructionRef;
                         instruction = instructionRef;
                         stackHead = ref stack.FastGet(frameBase);
                         stackHead = ref stack.FastGet(frameBase);
                         ref readonly var vc = ref RKC(ref stackHead, ref constHead, instruction);
                         ref readonly var vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        var table = context.Closure.GetUpValue(instruction.B);
-
-                        if (table.TryReadTable(out var luaTable) && luaTable.TryGetValue(vc, out var resultValue))
-                        {
-                            stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
-                            continue;
-                        }
-
-                        if (TryGetMetaTableValue(table, vc, ref context, out var doRestart))
+                        ref readonly var vb = ref (instruction.OpCode == OpCode.GetTable ? ref Unsafe.Add(ref stackHead, instruction.UIntB) : ref context.Closure.GetUpValueRef(instruction.B));
+                        var doRestart = false;
+                        if (vb.TryReadTable(out var luaTable) && luaTable.TryGetValue(vc, out var resultValue) || GetTableValueSlowPath(vb, vc, ref context, out resultValue, out doRestart))
                         {
                         {
                             if (doRestart) goto Restart;
                             if (doRestart) goto Restart;
-                            continue;
-                        }
-
-                        postOperation = PostOperationType.SetResult;
-                        return true;
-                    case OpCode.GetTable:
-                        instruction = instructionRef;
-                        stackHead = ref stack.FastGet(frameBase);
-                        ref readonly var vb = ref Unsafe.Add(ref stackHead, instruction.UIntB);
-                        vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (vb.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue))
-                        {
                             stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
                             stack.GetWithNotifyTop(instruction.A + frameBase) = resultValue;
                             continue;
                             continue;
                         }
                         }
 
 
-                        if (TryGetMetaTableValue(vb, vc, ref context, out doRestart))
-                        {
-                            if (doRestart) goto Restart;
-                            continue;
-                        }
-
                         postOperation = PostOperationType.SetResult;
                         postOperation = PostOperationType.SetResult;
                         return true;
                         return true;
                     case OpCode.SetTabUp:
                     case OpCode.SetTabUp:
@@ -359,16 +335,20 @@ public static partial class LuaVirtualMachine
                             }
                             }
                         }
                         }
 
 
-                        table = context.Closure.GetUpValue(instruction.A);
+                        var table = context.Closure.GetUpValue(instruction.A);
 
 
                         if (table.TryReadTable(out luaTable))
                         if (table.TryReadTable(out luaTable))
                         {
                         {
-                            luaTable[vb] = RKC(ref stackHead, ref constHead, instruction);
-                            continue;
+                            ref var valueRef = ref luaTable.FindValue(vb);
+                            if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
+                            {
+                                valueRef = RKC(ref stackHead, ref constHead, instruction);
+                                continue;
+                            }
                         }
                         }
 
 
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (TrySetMetaTableValue(table, vb, vc, ref context, out doRestart))
+                        if (SetTableValueSlowPath(table, vb, vc, ref context, out doRestart))
                         {
                         {
                             if (doRestart) goto Restart;
                             if (doRestart) goto Restart;
                             continue;
                             continue;
@@ -399,15 +379,16 @@ public static partial class LuaVirtualMachine
 
 
                         if (table.TryReadTable(out luaTable))
                         if (table.TryReadTable(out luaTable))
                         {
                         {
-                            if (luaTable.Metatable == null || !luaTable.Metatable.ContainsKey(Metamethods.NewIndex) || luaTable.ContainsKey(vb))
+                            ref var valueRef = ref luaTable.FindValue(vb);
+                            if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
                             {
                             {
-                                luaTable[vb] = RKC(ref stackHead, ref constHead, instruction);
+                                valueRef = RKC(ref stackHead, ref constHead, instruction);
                                 continue;
                                 continue;
                             }
                             }
                         }
                         }
 
 
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
-                        if (TrySetMetaTableValue(table, vb, vc, ref context, out doRestart))
+                        if (SetTableValueSlowPath(table, vb, vc, ref context, out doRestart))
                         {
                         {
                             if (doRestart) goto Restart;
                             if (doRestart) goto Restart;
                             continue;
                             continue;
@@ -425,21 +406,17 @@ public static partial class LuaVirtualMachine
                         stackHead = ref stack.FastGet(frameBase);
                         stackHead = ref stack.FastGet(frameBase);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         vc = ref RKC(ref stackHead, ref constHead, instruction);
                         table = Unsafe.Add(ref stackHead, instruction.UIntB);
                         table = Unsafe.Add(ref stackHead, instruction.UIntB);
-                        if (table.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue))
+
+                        doRestart = false;
+                        if ((table.TryReadTable(out luaTable) && luaTable.TryGetValue(vc, out resultValue)) || GetTableValueSlowPath(table, vc, ref context, out resultValue, out doRestart))
                         {
                         {
+                            if (doRestart) goto Restart;
                             Unsafe.Add(ref stackHead, iA) = resultValue;
                             Unsafe.Add(ref stackHead, iA) = resultValue;
                             Unsafe.Add(ref stackHead, iA + 1) = table;
                             Unsafe.Add(ref stackHead, iA + 1) = table;
                             stack.NotifyTop(iA + frameBase + 2);
                             stack.NotifyTop(iA + frameBase + 2);
                             continue;
                             continue;
                         }
                         }
 
 
-
-                        if (TryGetMetaTableValue(table, vc, ref context, out doRestart))
-                        {
-                            if (doRestart) goto Restart;
-                            continue;
-                        }
-
                         postOperation = PostOperationType.Self;
                         postOperation = PostOperationType.Self;
                         return true;
                         return true;
                     case OpCode.Add:
                     case OpCode.Add:
@@ -1315,126 +1292,176 @@ public static partial class LuaVirtualMachine
     }
     }
 
 
     [MethodImpl(MethodImplOptions.NoInlining)]
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static bool TryGetMetaTableValue(LuaValue table, LuaValue key, ref VirtualMachineExecutionContext context, out bool doRestart)
+    static bool GetTableValueSlowPath(LuaValue table, LuaValue key, ref VirtualMachineExecutionContext context, out LuaValue value, out bool doRestart)
     {
     {
-        var isSelf = context.Instruction.OpCode == OpCode.Self;
+        var targetTable = table;
+        const int MAX_LOOP = 100;
         doRestart = false;
         doRestart = false;
-        var state = context.State;
-        if (table.TryGetMetamethod(state, Metamethods.Index, out var metamethod))
+        var skip = targetTable.Type == LuaValueType.Table;
+        for (int i = 0; i < MAX_LOOP; i++)
         {
         {
-            if (!metamethod.TryReadFunction(out var indexTable))
+            if (table.TryReadTable(out var luaTable))
             {
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
-            }
-
-            var stack = context.Stack;
-            stack.Push(table);
-            stack.Push(key);
-
+                if (!skip && luaTable.TryGetValue(key, out value))
+                {
+                    return true;
+                }
 
 
-            var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 2);
+                skip = false;
 
 
-            context.Thread.PushCallStackFrame(newFrame);
+                var metatable = luaTable.Metatable;
+                if (metatable != null && metatable.TryGetValue(Metamethods.Index, out table))
+                {
+                    goto Function;
+                }
 
 
-            if (indexTable is Closure)
-            {
-                context.Push(newFrame);
-                doRestart = true;
+                value = default;
                 return true;
                 return true;
             }
             }
 
 
-            var task = indexTable.Invoke(ref context, newFrame, 2);
-
-            if (!task.IsCompleted)
+            if (!table.TryGetMetamethod(context.State, Metamethods.Index, out var metatableValue))
             {
             {
-                context.Task = task;
-                return false;
+                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
             }
             }
 
 
-            var awaiter = task.GetAwaiter();
-            context.Thread.PopCallStackFrame();
-            var ra = context.Instruction.A + context.FrameBase;
-            var resultCount = awaiter.GetResult();
-            context.Stack.Get(ra) = resultCount == 0 ? default : context.ResultsBuffer[0];
-            if (isSelf)
-            {
-                context.Stack.Get(ra + 1) = table;
-                context.Stack.NotifyTop(ra + 2);
-            }
-            else
+            table = metatableValue;
+        Function:
+            if (table.TryReadFunction(out var function))
             {
             {
-                context.Stack.NotifyTop(ra + 1);
+                return CallGetTableFunc(targetTable, function, key, ref context, out value, out doRestart);
             }
             }
-
-            context.ClearResultsBuffer(resultCount);
-            return true;
         }
         }
 
 
-        if (table.Type == LuaValueType.Table)
-        {
-            var ra = context.Instruction.A + context.FrameBase;
-            context.Stack.Get(ra) = default;
-            if (isSelf)
-            {
-                context.Stack.Get(ra + 1) = table;
-                context.Stack.NotifyTop(ra + 2);
-            }
-            else
-            {
-                context.Stack.NotifyTop(ra + 1);
-            }
+        throw new LuaRuntimeException(GetTracebacks(ref context), "loop in gettable");
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool CallGetTableFunc(LuaValue table, LuaFunction indexTable, LuaValue key, ref VirtualMachineExecutionContext context, out LuaValue result, out bool doRestart)
+    {
+        doRestart = false;
+        var stack = context.Stack;
+        stack.Push(table);
+        stack.Push(key);
+        var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 2);
 
 
+        context.Thread.PushCallStackFrame(newFrame);
+
+        if (indexTable is Closure)
+        {
+            context.Push(newFrame);
+            doRestart = true;
+            result = default;
             return true;
             return true;
         }
         }
 
 
+        var task = indexTable.Invoke(ref context, newFrame, 2);
 
 
-        LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
-        return false;
+        if (!task.IsCompleted)
+        {
+            context.Task = task;
+            result = default;
+            return false;
+        }
+
+        var awaiter = task.GetAwaiter();
+        context.Thread.PopCallStackFrame();
+        var resultCount = awaiter.GetResult();
+        result = resultCount == 0 ? default : context.ResultsBuffer[0];
+        context.ClearResultsBuffer(resultCount);
+        return true;
     }
     }
 
 
-    static bool TrySetMetaTableValue(LuaValue table, LuaValue key, LuaValue value,
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool SetTableValueSlowPath(LuaValue table, LuaValue key, LuaValue value,
         ref VirtualMachineExecutionContext context, out bool doRestart)
         ref VirtualMachineExecutionContext context, out bool doRestart)
     {
     {
+        var targetTable = table;
+        const int MAX_LOOP = 100;
         doRestart = false;
         doRestart = false;
-        var state = context.State;
-        if (table.TryGetMetamethod(state, Metamethods.NewIndex, out var metamethod))
+        var skip = targetTable.Type == LuaValueType.Table;
+        for (int i = 0; i < MAX_LOOP; i++)
         {
         {
-            if (!metamethod.TryReadFunction(out var indexTable))
+            if (table.TryReadTable(out var luaTable))
             {
             {
-                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
-            }
+                ref var valueRef = ref (skip ? ref Unsafe.NullRef<LuaValue>() : ref luaTable.FindValue(key));
+                skip = false;
+                if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
+                {
+                    valueRef = value;
+                    return true;
+                }
 
 
-            var thread = context.Thread;
-            var stack = thread.Stack;
-            stack.Push(table);
-            stack.Push(key);
-            stack.Push(value);
-            var newFrame = indexTable.CreateNewFrame(ref context, stack.Count - 3);
+                var metatable = luaTable.Metatable;
+                if (metatable == null || !metatable.TryGetValue(Metamethods.NewIndex, out table))
+                {
+                    if (Unsafe.IsNullRef(ref valueRef))
+                    {
+                        luaTable[key] = value;
+                        return true;
+                    }
 
 
-            context.Thread.PushCallStackFrame(newFrame);
+                    valueRef = value;
+                    return true;
+                }
+
+                goto Function;
+            }
 
 
-            if (indexTable is Closure)
+            if (!table.TryGetMetamethod(context.State, Metamethods.NewIndex, out var metatableValue))
             {
             {
-                context.Push(newFrame);
-                doRestart = true;
-                return true;
+                LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
             }
             }
 
 
-            var task = indexTable.Invoke(ref context, newFrame, 3);
-            if (!task.IsCompleted)
+            table = metatableValue;
+
+        Function:
+            if (table.TryReadFunction(out var function))
             {
             {
-                context.Task = task;
-                return false;
+                return CallSetTableFunc(targetTable, function, key, value, ref context, out doRestart);
             }
             }
+        }
+
+        throw new LuaRuntimeException(GetTracebacks(ref context), "loop in settable");
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static bool CallSetTableFunc(LuaValue table, LuaFunction newIndexFunction, LuaValue key, LuaValue value, ref VirtualMachineExecutionContext context, out bool doRestart)
+    {
+        doRestart = false;
+        var thread = context.Thread;
+        var stack = thread.Stack;
+        stack.Push(table);
+        stack.Push(key);
+        stack.Push(value);
+        var newFrame = newIndexFunction.CreateNewFrame(ref context, stack.Count - 3);
 
 
-            thread.PopCallStackFrame();
+        context.Thread.PushCallStackFrame(newFrame);
+
+        if (newIndexFunction is Closure)
+        {
+            context.Push(newFrame);
+            doRestart = true;
             return true;
             return true;
         }
         }
 
 
-        LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "index", table);
-        return false;
+        var task = newIndexFunction.Invoke(ref context, newFrame, 3);
+        if (!task.IsCompleted)
+        {
+            context.Task = task;
+            return false;
+        }
+
+        var resultCount = task.GetAwaiter().GetResult();
+        if (0 < resultCount)
+        {
+            context.ClearResultsBuffer(resultCount);
+        }
+
+        thread.PopCallStackFrame();
+        return true;
     }
     }
 
 
+
     [MethodImpl(MethodImplOptions.NoInlining)]
     [MethodImpl(MethodImplOptions.NoInlining)]
     static bool ExecuteBinaryOperationMetaMethod(LuaValue vb, LuaValue vc,
     static bool ExecuteBinaryOperationMetaMethod(LuaValue vb, LuaValue vc,
         ref VirtualMachineExecutionContext context, string name, string description, out bool doRestart)
         ref VirtualMachineExecutionContext context, string name, string description, out bool doRestart)

+ 13 - 0
src/Lua/Runtime/UpValue.cs

@@ -45,6 +45,19 @@ public sealed class UpValue
             return Thread!.Stack.Get(RegisterIndex);
             return Thread!.Stack.Get(RegisterIndex);
         }
         }
     }
     }
+    
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal ref readonly LuaValue GetValueRef()
+    {
+        if (IsClosed)
+        {
+            return ref value;
+        }
+        else
+        {
+            return ref Thread!.Stack.Get(RegisterIndex);
+        }
+    }
 
 
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     public void SetValue(LuaValue value)
     public void SetValue(LuaValue value)

+ 40 - 0
tests/Lua.Tests/MetatableTests.cs

@@ -46,4 +46,44 @@ return a + b
             Assert.That(table[3].Read<double>(), Is.EqualTo(9));
             Assert.That(table[3].Read<double>(), Is.EqualTo(9));
         });
         });
     }
     }
+
+    [Test]
+    public async Task Test_Metamethod_Index()
+    {
+        var source = @"
+metatable = {
+    __index = {x=1}
+}
+
+local a = {}
+setmetatable(a, metatable)
+assert(a.x == 1)
+metatable.__index= nil
+assert(a.x == nil)
+metatable.__index= function(a,b) return b end
+assert(a.x == 'x')
+";
+        await state.DoStringAsync(source);
+    }
+
+    [Test]
+    public async Task Test_Metamethod_NewIndex()
+    {
+        var source = @"
+metatable = {
+    __newindex = {}
+}
+
+local a = {}
+a.x = 1
+setmetatable(a, metatable)
+a.x = 2
+assert(a.x == 2)
+a.x = nil
+a.x = 2
+assert(a.x == nil)
+assert(metatable.__newindex.x == 2)
+";
+        await state.DoStringAsync(source);
+    }
 }
 }