Browse Source

Merge pull request #13 from AnnulusGames/type-metatable

Add metatable for type
Annulus Games 1 year ago
parent
commit
59f94a4554

+ 59 - 0
src/Lua/LuaState.cs

@@ -1,3 +1,4 @@
+using System.Diagnostics.CodeAnalysis;
 using Lua.Internal;
 using Lua.Loaders;
 using Lua.Runtime;
@@ -8,6 +9,7 @@ public sealed class LuaState
 {
     public const string DefaultChunkName = "chunk";
 
+    // states
     readonly LuaMainThread mainThread = new();
     FastListCore<UpValue> openUpValues;
     FastStackCore<LuaThread> threadStack;
@@ -32,6 +34,14 @@ public sealed class LuaState
 
     public ILuaModuleLoader ModuleLoader { get; set; } = FileModuleLoader.Instance;
 
+    // metatables
+    LuaTable? nilMetatable;
+    LuaTable? numberMetatable;
+    LuaTable? stringMetatable;
+    LuaTable? booleanMetatable;
+    LuaTable? functionMetatable;
+    LuaTable? threadMetatable;
+
     public static LuaState Create()
     {
         return new();
@@ -89,6 +99,55 @@ public sealed class LuaState
         };
     }
 
+    internal bool TryGetMetatable(LuaValue value, [NotNullWhen(true)] out LuaTable? result)
+    {
+        result = value.Type switch
+        {
+            LuaValueType.Nil => nilMetatable,
+            LuaValueType.Boolean => booleanMetatable,
+            LuaValueType.String => stringMetatable,
+            LuaValueType.Number => numberMetatable,
+            LuaValueType.Function => functionMetatable,
+            LuaValueType.Thread => threadMetatable,
+            LuaValueType.UserData => value.Read<LuaUserData>().Metatable,
+            LuaValueType.Table => value.Read<LuaTable>().Metatable,
+            _ => null
+        };
+
+        return result != null;
+    }
+
+    internal void SetMetatable(LuaValue value, LuaTable metatable)
+    {
+        switch (value.Type)
+        {
+            case LuaValueType.Nil:
+                nilMetatable = metatable;
+                break;
+            case LuaValueType.Boolean:
+                booleanMetatable = metatable;
+                break;
+            case LuaValueType.String:
+                stringMetatable = metatable;
+                break;
+            case LuaValueType.Number:
+                numberMetatable = metatable;
+                break;
+            case LuaValueType.Function:
+                functionMetatable = metatable;
+                break;
+            case LuaValueType.Thread:
+                threadMetatable = metatable;
+                break;
+            case LuaValueType.UserData:
+                value.Read<LuaUserData>().Metatable = metatable;
+                break;
+            case LuaValueType.Table:
+                value.Read<LuaTable>().Metatable = metatable;
+                break;
+        }
+    }
+
     internal UpValue GetOrAddUpValue(LuaThread thread, int registerIndex)
     {
         foreach (var upValue in openUpValues.AsSpan())

+ 11 - 0
src/Lua/LuaUserData.cs

@@ -0,0 +1,11 @@
+namespace Lua;
+
+public abstract class LuaUserData
+{
+    public LuaTable? Metatable { get; set; }
+}
+
+public class LuaUserData<T>(T value) : LuaUserData
+{
+    public T Value { get; } = value;
+}

+ 26 - 9
src/Lua/LuaValue.cs

@@ -132,12 +132,27 @@ public readonly struct LuaValue : IEquatable<LuaValue>
                     break;
                 }
             case LuaValueType.UserData:
-                if (referenceValue is T userData)
+                if (t == typeof(LuaUserData))
                 {
-                    result = userData;
+                    var v = (LuaUserData)referenceValue!;
+                    result = Unsafe.As<LuaUserData, T>(ref v);
                     return true;
                 }
-                break;
+                else if (t == typeof(LuaUserData<T>))
+                {
+                    var v = (LuaUserData<T>)referenceValue!;
+                    result = Unsafe.As<LuaUserData<T>, T>(ref v);
+                    return true;
+                }
+                else if (t == typeof(object))
+                {
+                    result = (T)referenceValue!;
+                    return true;
+                }
+                else
+                {
+                    break;
+                }
             case LuaValueType.Table:
                 if (t == typeof(LuaTable))
                 {
@@ -210,7 +225,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
         referenceValue = value;
     }
 
-    public LuaValue(object? value)
+    public LuaValue(LuaUserData value)
     {
         type = LuaValueType.UserData;
         referenceValue = value;
@@ -246,6 +261,11 @@ public readonly struct LuaValue : IEquatable<LuaValue>
         return new(value);
     }
 
+    public static implicit operator LuaValue(LuaUserData value)
+    {
+        return new(value);
+    }
+
     public override int GetHashCode()
     {
         var valueHash = type switch
@@ -254,10 +274,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             LuaValueType.Boolean => Read<bool>().GetHashCode(),
             LuaValueType.String => Read<string>().GetHashCode(),
             LuaValueType.Number => Read<double>().GetHashCode(),
-            LuaValueType.Function => Read<LuaFunction>().GetHashCode(),
-            LuaValueType.Thread => Read<LuaThread>().GetHashCode(),
-            LuaValueType.Table => Read<LuaTable>().GetHashCode(),
-            LuaValueType.UserData => referenceValue == null ? 0 : referenceValue.GetHashCode(),
+            LuaValueType.Function or LuaValueType.Thread or LuaValueType.Table or LuaValueType.UserData => referenceValue!.GetHashCode(),
             _ => 0,
         };
 
@@ -352,7 +369,7 @@ public readonly struct LuaValue : IEquatable<LuaValue>
 
     internal async ValueTask<int> CallToStringAsync(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        if (this.TryGetMetamethod(Metamethods.ToString, out var metamethod))
+        if (this.TryGetMetamethod(context.State, Metamethods.ToString, out var metamethod))
         {
             if (!metamethod.TryRead<LuaFunction>(out var func))
             {

+ 4 - 12
src/Lua/Runtime/LuaValueRuntimeExtensions.cs

@@ -5,19 +5,11 @@ namespace Lua.Runtime;
 
 internal static class LuaRuntimeExtensions
 {
-    public static bool TryGetMetamethod(this LuaValue value, string methodName, out LuaValue result)
+    public static bool TryGetMetamethod(this LuaValue value, LuaState state, string methodName, out LuaValue result)
     {
-        if (value.TryRead<LuaTable>(out var table) &&
-            table.Metatable != null &&
-            table.Metatable.TryGetValue(methodName, out result))
-        {
-            return true;
-        }
-        else
-        {
-            result = default;
-            return false;
-        }
+        result = default;
+        return state.TryGetMetatable(value, out var metatable) &&
+            metatable.TryGetValue(methodName, out result);
     }
 
 #if NET6_0_OR_GREATER

+ 16 - 16
src/Lua/Runtime/LuaVirtualMachine.cs

@@ -123,7 +123,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = valueB + valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Add, out var metamethod) || vc.TryGetMetamethod(Metamethods.Add, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Add, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Add, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -162,7 +162,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = valueB - valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Sub, out var metamethod) || vc.TryGetMetamethod(Metamethods.Sub, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Sub, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Sub, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -201,7 +201,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = valueB * valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Mul, out var metamethod) || vc.TryGetMetamethod(Metamethods.Mul, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Mul, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Mul, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -240,7 +240,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = valueB / valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Div, out var metamethod) || vc.TryGetMetamethod(Metamethods.Div, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Div, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Div, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -279,7 +279,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = valueB % valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Mod, out var metamethod) || vc.TryGetMetamethod(Metamethods.Mod, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Mod, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Mod, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -318,7 +318,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = Math.Pow(valueB, valueC);
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Pow, out var metamethod) || vc.TryGetMetamethod(Metamethods.Pow, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Pow, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Pow, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -356,7 +356,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = -valueB;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Unm, out var metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Unm, out var metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -411,7 +411,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = str.Length;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Len, out var metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Len, out var metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -467,7 +467,7 @@ public static partial class LuaVirtualMachine
                         {
                             stack.UnsafeGet(RA) = strB + strC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Concat, out var metamethod) || vc.TryGetMetamethod(Metamethods.Concat, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Concat, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Concat, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -508,7 +508,7 @@ public static partial class LuaVirtualMachine
                         var vc = RK(stack, chunk, instruction.C, frame.Base);
                         var compareResult = vb == vc;
 
-                        if (!compareResult && (vb.TryGetMetamethod(Metamethods.Eq, out var metamethod) || vc.TryGetMetamethod(Metamethods.Eq, out metamethod)))
+                        if (!compareResult && (vb.TryGetMetamethod(state, Metamethods.Eq, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Eq, out metamethod)))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -553,7 +553,7 @@ public static partial class LuaVirtualMachine
                         {
                             compareResult = valueB < valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Lt, out var metamethod) || vc.TryGetMetamethod(Metamethods.Lt, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Lt, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Lt, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -602,7 +602,7 @@ public static partial class LuaVirtualMachine
                         {
                             compareResult = valueB <= valueC;
                         }
-                        else if (vb.TryGetMetamethod(Metamethods.Le, out var metamethod) || vc.TryGetMetamethod(Metamethods.Le, out metamethod))
+                        else if (vb.TryGetMetamethod(state, Metamethods.Le, out var metamethod) || vc.TryGetMetamethod(state, Metamethods.Le, out metamethod))
                         {
                             if (!metamethod.TryRead<LuaFunction>(out var func))
                             {
@@ -667,7 +667,7 @@ public static partial class LuaVirtualMachine
                         var va = stack.UnsafeGet(RA);
                         if (!va.TryRead<LuaFunction>(out var func))
                         {
-                            if (va.TryGetMetamethod(Metamethods.Call, out var metamethod) && metamethod.TryRead<LuaFunction>(out func))
+                            if (va.TryGetMetamethod(state, Metamethods.Call, out var metamethod) && metamethod.TryRead<LuaFunction>(out func))
                             {
                             }
                             else
@@ -717,7 +717,7 @@ public static partial class LuaVirtualMachine
                         var va = stack.UnsafeGet(RA);
                         if (!va.TryRead<LuaFunction>(out var func))
                         {
-                            if (!va.TryGetMetamethod(Metamethods.Call, out var metamethod) && !metamethod.TryRead<LuaFunction>(out func))
+                            if (!va.TryGetMetamethod(state, Metamethods.Call, out var metamethod) && !metamethod.TryRead<LuaFunction>(out func))
                             {
                                 LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(state, chunk, pc), "call", metamethod);
                             }
@@ -888,7 +888,7 @@ public static partial class LuaVirtualMachine
         {
             return result;
         }
-        else if (table.TryGetMetamethod(Metamethods.Index, out var metamethod))
+        else if (table.TryGetMetamethod(state, Metamethods.Index, out var metamethod))
         {
             if (!metamethod.TryRead<LuaFunction>(out var indexTable))
             {
@@ -939,7 +939,7 @@ public static partial class LuaVirtualMachine
         {
             t[key] = value;
         }
-        else if (table.TryGetMetamethod(Metamethods.NewIndex, out var metamethod))
+        else if (table.TryGetMetamethod(state, Metamethods.NewIndex, out var metamethod))
         {
             if (!metamethod.TryRead<LuaFunction>(out var indexTable))
             {

+ 1 - 1
src/Lua/Standard/Mathematics/RandomFunction.cs

@@ -10,7 +10,7 @@ public sealed class RandomFunction : LuaFunction
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        var rand = context.State.Environment[RandomInstanceKey].Read<Random>();
+        var rand = context.State.Environment[RandomInstanceKey].Read<LuaUserData<Random>>().Value;
 
         if (context.ArgumentCount == 0)
         {

+ 1 - 1
src/Lua/Standard/Mathematics/RandomSeedFunction.cs

@@ -10,7 +10,7 @@ public sealed class RandomSeedFunction : LuaFunction
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
         var arg0 = context.ReadArgument<double>(0);
-        context.State.Environment[RandomFunction.RandomInstanceKey] = new(new Random((int)BitConverter.DoubleToInt64Bits(arg0)));
+        context.State.Environment[RandomFunction.RandomInstanceKey] = new LuaUserData<Random>(new Random((int)BitConverter.DoubleToInt64Bits(arg0)));
         return new(0);
     }
 }

+ 1 - 1
src/Lua/Standard/OpenLibExtensions.cs

@@ -96,7 +96,7 @@ public static class OpenLibExtensions
 
     public static void OpenMathLibrary(this LuaState state)
     {
-        state.Environment[RandomFunction.RandomInstanceKey] = new(new Random());
+        state.Environment[RandomFunction.RandomInstanceKey] = new LuaUserData<Random>(new Random());
         state.Environment["pi"] = Math.PI;
         state.Environment["huge"] = double.PositiveInfinity;