Browse Source

Refactoring

AnnulusGames 1 year ago
parent
commit
685f67a408

+ 1 - 6
src/Lua/Exceptions.cs

@@ -68,12 +68,7 @@ public class LuaRuntimeException(Tracebacks tracebacks, string message) : LuaExc
         throw new LuaRuntimeException(tracebacks, $"bad argument #{argumentId} to '{functionName}' ({string.Join(" or ", expected)} expected)");
     }
 
-    public static void BadArgument(Tracebacks tracebacks, int argumentId, string functionName, LuaValueType expected)
-    {
-        throw new LuaRuntimeException(tracebacks, $"bad argument #{argumentId} to '{functionName}' ({expected} expected, got no value)");
-    }
-
-    public static void BadArgument(Tracebacks tracebacks, int argumentId, string functionName, LuaValueType expected, LuaValueType actual)
+    public static void BadArgument(Tracebacks tracebacks, int argumentId, string functionName, string expected, string actual)
     {
         throw new LuaRuntimeException(tracebacks, $"bad argument #{argumentId} to '{functionName}' ({expected} expected, got {actual})");
     }

+ 30 - 3
src/Lua/LuaFunction.cs

@@ -28,13 +28,40 @@ public abstract partial class LuaFunction
         }
     }
 
+    public virtual string Name => GetType().Name;
     protected abstract ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken);
 
-    protected static void ThrowIfArgumentNotExists(LuaFunctionExecutionContext context, string chunkName, int index)
+    protected void ThrowIfArgumentNotExists(LuaFunctionExecutionContext context, int index)
     {
-        if (context.ArgumentCount == index)
+        if (context.ArgumentCount <= index)
         {
-            LuaRuntimeException.BadArgument(context.State.GetTracebacks(), index + 1, chunkName);
+            LuaRuntimeException.BadArgument(context.State.GetTracebacks(), index + 1, Name);
         }
     }
+
+    protected LuaValue ReadArgument(LuaFunctionExecutionContext context, int index)
+    {
+        ThrowIfArgumentNotExists(context, index);
+        return context.Arguments[index];
+    }
+
+    protected T ReadArgument<T>(LuaFunctionExecutionContext context, int index)
+    {
+        ThrowIfArgumentNotExists(context, index);
+
+        var arg = context.Arguments[index];
+        if (!arg.TryRead<T>(out var argValue))
+        {
+            if (LuaValue.TryGetLuaValueType(typeof(T), out var type))
+            {
+                LuaRuntimeException.BadArgument(context.State.GetTracebacks(), 1, Name, type.ToString(), arg.Type.ToString());
+            }
+            else
+            {
+                LuaRuntimeException.BadArgument(context.State.GetTracebacks(), 1, Name, typeof(T).Name, arg.Type.ToString());
+            }
+        }
+
+        return argValue;
+    }
 }

+ 1 - 1
src/Lua/LuaState.cs

@@ -61,7 +61,7 @@ public sealed class LuaState
         return callStack.Peek();
     }
 
-    internal Tracebacks GetTracebacks()
+    public Tracebacks GetTracebacks()
     {
         return new()
         {

+ 55 - 0
src/Lua/LuaValue.cs

@@ -1,5 +1,6 @@
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
+using Lua.Runtime;
 
 namespace Lua;
 
@@ -280,4 +281,58 @@ public readonly struct LuaValue : IEquatable<LuaValue>
             _ => "",
         };
     }
+
+    public static bool TryGetLuaValueType(Type type, out LuaValueType result)
+    {
+        if (type == typeof(double) || type == typeof(float) || type == typeof(int) || type == typeof(long))
+        {
+            result = LuaValueType.Number;
+            return true;
+        }
+        else if (type == typeof(bool))
+        {
+            result = LuaValueType.Boolean;
+            return true;
+        }
+        else if (type == typeof(string))
+        {
+            result = LuaValueType.String;
+            return true;
+        }
+        else if (type == typeof(LuaFunction) || type.IsSubclassOf(typeof(LuaFunction)))
+        {
+            result = LuaValueType.Function;
+            return true;
+        }
+        else if (type == typeof(LuaTable))
+        {
+            result = LuaValueType.Table;
+            return true;
+        }
+
+        result = default;
+        return false;
+    }
+
+    internal async ValueTask<int> CallToStringAsync(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
+    {
+        if (this.TryGetMetamethod(Metamethods.ToString, out var metamethod))
+        {
+            if (!metamethod.TryRead<LuaFunction>(out var func))
+            {
+                LuaRuntimeException.AttemptInvalidOperation(context.State.GetTracebacks(), "call", metamethod);
+            }
+
+            context.State.Push(value);
+            return await func.InvokeAsync(context with
+            {
+                ArgumentCount = 1,
+            }, buffer, cancellationToken);
+        }
+        else
+        {
+            buffer.Span[0] = value.ToString()!;
+            return 1;
+        }
+    }
 }

+ 8 - 6
src/Lua/Standard/Base/AssertFunction.cs

@@ -2,18 +2,20 @@ namespace Lua.Standard.Base;
 
 public sealed class AssertFunction : LuaFunction
 {
-    public const string Name = "assert";
+    public override string Name => "assert";
     public static readonly AssertFunction Instance = new();
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        ThrowIfArgumentNotExists(context, Name, 0);
+        var arg0 = ReadArgument(context, 0);
 
-        if (!context.Arguments[0].ToBoolean())
+        if (!arg0.ToBoolean())
         {
-            var message = context.ArgumentCount >= 2
-                ? context.Arguments[1].Read<string>()
-                : $"assertion failed!";
+            var message = "assertion failed!";
+            if (context.ArgumentCount >= 2)
+            {
+                message = ReadArgument<string>(context, 1);
+            }
 
             throw new LuaAssertionException(context.State.GetTracebacks(), message);
         }

+ 1 - 1
src/Lua/Standard/Base/ErrorFunction.cs

@@ -2,7 +2,7 @@ namespace Lua.Standard.Base;
 
 public sealed class ErrorFunction : LuaFunction
 {
-    public const string Name = "error";
+    public override string Name => "error";
     public static readonly ErrorFunction Instance = new();
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)

+ 2 - 2
src/Lua/Standard/Base/GetMetatableFunction.cs

@@ -5,12 +5,12 @@ namespace Lua.Standard.Base;
 
 public sealed class GetMetatableFunction : LuaFunction
 {
-    public const string Name = "getmetatable";
+    public override string Name => "getmetatable";
     public static readonly GetMetatableFunction Instance = new();
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        ThrowIfArgumentNotExists(context, Name, 0);
+        ThrowIfArgumentNotExists(context, 0);
 
         var obj = context.Arguments[0];
 

+ 2 - 5
src/Lua/Standard/Base/PrintFunction.cs

@@ -1,11 +1,8 @@
-
-using Lua.Runtime;
-
 namespace Lua.Standard.Base;
 
 public sealed class PrintFunction : LuaFunction
 {
-    public const string Name = "print";
+    public override string Name => "print";
     public static readonly PrintFunction Instance = new();
 
     LuaValue[] buffer = new LuaValue[1];
@@ -14,7 +11,7 @@ public sealed class PrintFunction : LuaFunction
     {
         for (int i = 0; i < context.ArgumentCount; i++)
         {
-            await ToStringFunction.ToStringCore(context, context.Arguments[i], this.buffer, cancellationToken);
+            await context.Arguments[i].CallToStringAsync(context, this.buffer, cancellationToken);
             Console.Write(this.buffer[0]);
             Console.Write('\t');
         }

+ 1 - 1
src/Lua/Standard/Base/RawEqualFunction.cs

@@ -3,7 +3,7 @@ namespace Lua.Standard.Base;
 
 public sealed class RawEqualFunction : LuaFunction
 {
-    public const string Name = "rawequal";
+    public override string Name => "rawequal";
     public static readonly RawEqualFunction Instance = new();
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)

+ 5 - 11
src/Lua/Standard/Base/RawGetFunction.cs

@@ -3,21 +3,15 @@ namespace Lua.Standard.Base;
 
 public sealed class RawGetFunction : LuaFunction
 {
-    public const string Name = "rawget";
+    public override string Name => "rawget";
     public static readonly RawGetFunction Instance = new();
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        ThrowIfArgumentNotExists(context, Name, 0);
-        ThrowIfArgumentNotExists(context, Name, 1);
-
-        var arg0 = context.Arguments[0];
-        if (!arg0.TryRead<LuaTable>(out var table))
-        {
-            LuaRuntimeException.BadArgument(context.State.GetTracebacks(), 1, Name, LuaValueType.Table, arg0.Type);
-        }
-
-        buffer.Span[0] = table[context.Arguments[1]];
+        var arg0 = ReadArgument<LuaTable>(context, 0);
+        var arg1 = ReadArgument(context, 1);
+        
+        buffer.Span[0] = arg0[arg1];
         return new(1);
     }
 }

+ 5 - 12
src/Lua/Standard/Base/RawSetFunction.cs

@@ -3,23 +3,16 @@ namespace Lua.Standard.Base;
 
 public sealed class RawSetFunction : LuaFunction
 {
-    public const string Name = "rawset";
+    public override string Name => "rawset";
     public static readonly RawSetFunction Instance = new();
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        ThrowIfArgumentNotExists(context, Name, 0);
-        ThrowIfArgumentNotExists(context, Name, 1);
-        ThrowIfArgumentNotExists(context, Name, 2);
-
-        var arg0 = context.Arguments[0];
-        if (!arg0.TryRead<LuaTable>(out var table))
-        {
-            LuaRuntimeException.BadArgument(context.State.GetTracebacks(), 1, Name, LuaValueType.Table, arg0.Type);
-        }
-
-        table[context.Arguments[1]] = context.Arguments[2];
+        var arg0 = ReadArgument<LuaTable>(context, 0);
+        var arg1 = ReadArgument(context, 1);
+        var arg2 = ReadArgument(context, 2);
 
+        arg0[arg1] = arg2;
         return new(0);
     }
 }

+ 6 - 13
src/Lua/Standard/Base/SetMetatableFunction.cs

@@ -5,19 +5,12 @@ namespace Lua.Standard.Base;
 
 public sealed class SetMetatableFunction : LuaFunction
 {
-    public const string Name = "setmetatable";
+    public override string Name => "setmetatable";
     public static readonly SetMetatableFunction Instance = new();
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        ThrowIfArgumentNotExists(context, Name, 0);
-        ThrowIfArgumentNotExists(context, Name, 1);
-
-        var arg0 = context.Arguments[0];
-        if (!arg0.TryRead<LuaTable>(out var table))
-        {
-            LuaRuntimeException.BadArgument(context.State.GetTracebacks(), 1, Name, LuaValueType.Table, arg0.Type);
-        }
+        var arg0 = ReadArgument<LuaTable>(context, 0);
 
         var arg1 = context.Arguments[1];
         if (arg1.Type is not (LuaValueType.Nil or LuaValueType.Table))
@@ -25,20 +18,20 @@ public sealed class SetMetatableFunction : LuaFunction
             LuaRuntimeException.BadArgument(context.State.GetTracebacks(), 2, Name, [LuaValueType.Nil, LuaValueType.Table]);
         }
 
-        if (table.Metatable != null && table.Metatable.TryGetValue(Metamethods.Metatable, out _))
+        if (arg0.Metatable != null && arg0.Metatable.TryGetValue(Metamethods.Metatable, out _))
         {
             throw new LuaRuntimeException(context.State.GetTracebacks(), "cannot change a protected metatable");
         }
         else if (arg1.Type is LuaValueType.Nil)
         {
-            table.Metatable = null;
+            arg0.Metatable = null;
         }
         else
         {
-            table.Metatable = arg1.Read<LuaTable>();
+            arg0.Metatable = arg1.Read<LuaTable>();
         }
 
-        buffer.Span[0] = table;
+        buffer.Span[0] = arg0;
         return new(1);
     }
 }

+ 3 - 29
src/Lua/Standard/Base/ToStringFunction.cs

@@ -1,39 +1,13 @@
-
-using System.Buffers;
-using Lua.Runtime;
-
 namespace Lua.Standard.Base;
 
 public sealed class ToStringFunction : LuaFunction
 {
-    public const string Name = "tostring";
+    public override string Name => "tostring";
     public static readonly ToStringFunction Instance = new();
 
     protected override ValueTask<int> InvokeAsyncCore(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken)
     {
-        ThrowIfArgumentNotExists(context, Name, 0);
-        return ToStringCore(context, context.Arguments[0], buffer, cancellationToken);
-    }
-
-    internal static async ValueTask<int> ToStringCore(LuaFunctionExecutionContext context, LuaValue value, Memory<LuaValue> buffer, CancellationToken cancellationToken)
-    {
-        if (value.TryGetMetamethod(Metamethods.ToString, out var metamethod))
-        {
-            if (!metamethod.TryRead<LuaFunction>(out var func))
-            {
-                LuaRuntimeException.AttemptInvalidOperation(context.State.GetTracebacks(), "call", metamethod);
-            }
-
-            context.State.Push(value);
-            return await func.InvokeAsync(context with
-            {
-                ArgumentCount = 1,
-            }, buffer, cancellationToken);
-        }
-        else
-        {
-            buffer.Span[0] = value.ToString()!;
-            return 1;
-        }
+        var arg0 = ReadArgument(context, 0);
+        return arg0.CallToStringAsync(context, buffer, cancellationToken);
     }
 }

+ 15 - 8
src/Lua/Standard/OpenLibExtensions.cs

@@ -4,17 +4,24 @@ namespace Lua.Standard;
 
 public static class OpenLibExtensions
 {
+    static readonly LuaFunction[] baseFunctions = [
+        AssertFunction.Instance,
+        ErrorFunction.Instance,
+        PrintFunction.Instance,
+        RawGetFunction.Instance,
+        RawSetFunction.Instance,
+        GetMetatableFunction.Instance,
+        SetMetatableFunction.Instance,
+        ToStringFunction.Instance
+    ];
+
     public static void OpenBaseLibrary(this LuaState state)
     {
         state.Environment["_G"] = state.Environment;
         state.Environment["_VERSION"] = "Lua 5.2";
-        state.Environment[AssertFunction.Name] = AssertFunction.Instance;
-        state.Environment[ErrorFunction.Name] = ErrorFunction.Instance;
-        state.Environment[PrintFunction.Name] = PrintFunction.Instance;
-        state.Environment[RawGetFunction.Name] = RawGetFunction.Instance;
-        state.Environment[RawSetFunction.Name] = RawSetFunction.Instance;
-        state.Environment[GetMetatableFunction.Name] = GetMetatableFunction.Instance;
-        state.Environment[SetMetatableFunction.Name] = SetMetatableFunction.Instance;
-        state.Environment[ToStringFunction.Name] = ToStringFunction.Instance;
+        foreach (var func in baseFunctions)
+        {
+            state.Environment[func.Name] = func;
+        }
     }
 }