2
0
AnnulusGames 1 жил өмнө
parent
commit
a6d2f53ff3

+ 59 - 0
src/Lua/LuaState.cs

@@ -1,3 +1,4 @@
+using System.Diagnostics.CodeAnalysis;
 using Lua.Internal;
 using Lua.Loaders;
 using Lua.Runtime;
@@ -8,6 +9,7 @@ public sealed class LuaState
 {
     public const string DefaultChunkName = "chunk";
 
+    // states
     readonly LuaMainThread mainThread = new();
     FastListCore<UpValue> openUpValues;
     FastStackCore<LuaThread> threadStack;
@@ -32,6 +34,14 @@ public sealed class LuaState
 
     public ILuaModuleLoader ModuleLoader { get; set; } = FileModuleLoader.Instance;
 
+    // metatables
+    LuaTable? nilMetatable;
+    LuaTable? numberMetatable;
+    LuaTable? stringMetatable;
+    LuaTable? booleanMetatable;
+    LuaTable? functionMetatable;
+    LuaTable? threadMetatable;
+
     public static LuaState Create()
     {
         return new();
@@ -89,6 +99,55 @@ public sealed class LuaState
         };
     }
 
+    internal bool TryGetMetatable(LuaValue value, [NotNullWhen(true)] out LuaTable? result)
+    {
+        result = value.Type switch
+        {
+            LuaValueType.Nil => nilMetatable,
+            LuaValueType.Boolean => booleanMetatable,
+            LuaValueType.String => stringMetatable,
+            LuaValueType.Number => numberMetatable,
+            LuaValueType.Function => functionMetatable,
+            LuaValueType.Thread => threadMetatable,
+            LuaValueType.UserData => value.Read<LuaUserData>().Metatable,
+            LuaValueType.Table => value.Read<LuaTable>().Metatable,
+            _ => null
+        };
+
+        return result != null;
+    }
+
+    internal void SetMetatable(LuaValue value, LuaTable metatable)
+    {
+        switch (value.Type)
+        {
+            case LuaValueType.Nil:
+                nilMetatable = metatable;
+                break;
+            case LuaValueType.Boolean:
+                booleanMetatable = metatable;
+                break;
+            case LuaValueType.String:
+                stringMetatable = metatable;
+                break;
+            case LuaValueType.Number:
+                numberMetatable = metatable;
+                break;
+            case LuaValueType.Function:
+                functionMetatable = metatable;
+                break;
+            case LuaValueType.Thread:
+                threadMetatable = metatable;
+                break;
+            case LuaValueType.UserData:
+                value.Read<LuaUserData>().Metatable = metatable;
+                break;
+            case LuaValueType.Table:
+                value.Read<LuaTable>().Metatable = metatable;
+                break;
+        }
+    }
+
     internal UpValue GetOrAddUpValue(LuaThread thread, int registerIndex)
     {
         foreach (var upValue in openUpValues.AsSpan())

+ 1 - 1
src/Lua/LuaValue.cs

@@ -369,7 +369,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
 
     internal async ValueTask<int> CallToStringAsync(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        if (this.TryGetMetamethod(Metamethods.ToString, out var metamethod))
+        if (this.TryGetMetamethod(context.State, Metamethods.ToString, out var metamethod))
         {
             if (!metamethod.TryRead<LuaFunction>(out var func))
             {

+ 4 - 12
src/Lua/Runtime/LuaValueRuntimeExtensions.cs

@@ -5,19 +5,11 @@ namespace Lua.Runtime;
 
 internal static class LuaRuntimeExtensions
 {
-    public static bool TryGetMetamethod(this LuaValue value, string methodName, out LuaValue result)
+    public static bool TryGetMetamethod(this LuaValue value, LuaState state, string methodName, out LuaValue result)
     {
-        if (value.TryRead<LuaTable>(out var table) &&
-            table.Metatable != null &&
-            table.Metatable.TryGetValue(methodName, out result))
-        {
-            return true;
-        }
-        else
-        {
-            result = default;
-            return false;
-        }
+        result = default;
+        return state.TryGetMetatable(value, out var metatable) &&
+            metatable.TryGetValue(methodName, out result);
     }
 
 #if NET6_0_OR_GREATER

+ 16 - 16
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -123,7 +123,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = valueB + valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Add, out var metamethod) || vc.TryGetMetamethod(Metamethods.Add, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Add, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Add, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -162,7 +162,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = valueB - valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Sub, out var metamethod) || vc.TryGetMetamethod(Metamethods.Sub, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Sub, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Sub, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -201,7 +201,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = valueB * valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Mul, out var metamethod) || vc.TryGetMetamethod(Metamethods.Mul, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Mul, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Mul, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -240,7 +240,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = valueB / valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Div, out var metamethod) || vc.TryGetMetamethod(Metamethods.Div, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Div, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Div, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -279,7 +279,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = valueB % valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Mod, out var metamethod) || vc.TryGetMetamethod(Metamethods.Mod, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Mod, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Mod, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -318,7 +318,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = Math.Pow(valueB, valueC);
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Pow, out var metamethod) || vc.TryGetMetamethod(Metamethods.Pow, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Pow, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Pow, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -356,7 +356,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = -valueB;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Unm, out var metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Unm, out var metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -411,7 +411,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = str.Length;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Len, out var metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Len, out var metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -467,7 +467,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = strB + strC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Concat, out var metamethod) || vc.TryGetMetamethod(Metamethods.Concat, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Concat, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Concat, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -508,7 +508,7 @@ public static partial class LuaVirtualMachine
                         var vc = RK(stack, chunk, instruction.C, frame.Base);
                         var compareResult = vb == vc;
 
-                        if (!compareResult && (vb.TryGetMetamethod(Metamethods.Eq, out var metamethod) || vc.TryGetMetamethod(Metamethods.Eq, out metamethod)))
+                        if (!compareResult && (vb.TryGetMetamethod(state, Metamethods.Eq, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Eq, out metamethod)))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -553,7 +553,7 @@ public static partial class LuaVirtualMachine
                         {
                             compareResult = valueB < valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Lt, out var metamethod) || vc.TryGetMetamethod(Metamethods.Lt, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Lt, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Lt, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -602,7 +602,7 @@ public static partial class LuaVirtualMachine
                         {
                             compareResult = valueB <= valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Le, out var metamethod) || vc.TryGetMetamethod(Metamethods.Le, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Le, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Le, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -667,7 +667,7 @@ public static partial class LuaVirtualMachine
                         var va = stack.UnsafeGet(RA);
                         if (!va.TryRead<LuaFunction>(out var func))
                         {
-                            if (va.TryGetMetamethod(Metamethods.Call, out var metamethod) && metamethod.TryRead<LuaFunction>(out func))
+                            if (va.TryGetMetamethod(state, Metamethods.Call, out var metamethod) && metamethod.TryRead<LuaFunction>(out func))
                             {
                             }
                             else
@@ -717,7 +717,7 @@ public static partial class LuaVirtualMachine
                         var va = stack.UnsafeGet(RA);
                         if (!va.TryRead<LuaFunction>(out var func))
                         {
-                            if (!va.TryGetMetamethod(Metamethods.Call, out var metamethod) && !metamethod.TryRead<LuaFunction>(out func))
+                            if (!va.TryGetMetamethod(state, Metamethods.Call, out var metamethod) && !metamethod.TryRead<LuaFunction>(out func))
                             {
                                 LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(state, chunk, pc), "call", metamethod);
                             }
@@ -888,7 +888,7 @@ public static partial class LuaVirtualMachine
         {
             return result;
         }
-        else if (table.TryGetMetamethod(Metamethods.Index, out var metamethod))
+        else if (table.TryGetMetamethod(state, Metamethods.Index, out var metamethod))
         {
             if (!metamethod.TryRead<LuaFunction>(out var indexTable))
             {
@@ -939,7 +939,7 @@ public static partial class LuaVirtualMachine
         {
             t[key] = value;
         }
-        else if (table.TryGetMetamethod(Metamethods.NewIndex, out var metamethod))
+        else if (table.TryGetMetamethod(state, Metamethods.NewIndex, out var metamethod))
         {
             if (!metamethod.TryRead<LuaFunction>(out var indexTable))
             {