Browse Source

feat: enhance module loading and package management

Akeit0 7 months ago
parent
commit
35c654dfba
3 changed files with 131 additions and 16 deletions
  1. 5 3
      src/Lua/LuaState.cs
  2. 112 8
      src/Lua/Standard/ModuleLibrary.cs
  3. 14 5
      src/Lua/Standard/OpenLibsExtensions.cs

+ 5 - 3
src/Lua/LuaState.cs

@@ -4,8 +4,8 @@ using System.Runtime.CompilerServices;
 using Lua.Internal;
 using Lua.Loaders;
 using Lua.Runtime;
+using Lua.Standard;
 using System.Buffers;
-using System.Text;
 
 namespace Lua;
 
@@ -15,7 +15,6 @@ public sealed class LuaState
     readonly LuaMainThread mainThread;
     FastListCore<UpValue> openUpValues;
     FastStackCore<LuaThread> threadStack;
-    readonly LuaTable packages = new();
     readonly LuaTable environment;
     readonly LuaTable registry = new();
     readonly UpValue envUpValue;
@@ -31,7 +30,8 @@ public sealed class LuaState
 
     public LuaTable Environment => environment;
     public LuaTable Registry => registry;
-    public LuaTable LoadedModules => packages;
+    public LuaTable LoadedModules => registry[ModuleLibrary.LoadedKeyForRegistry].Read<LuaTable>();
+    public LuaTable PreloadModules => registry[ModuleLibrary.PreloadKeyForRegistry].Read<LuaTable>();
     public LuaMainThread MainThread => mainThread;
 
     public LuaThreadAccess TopLevelAccess => new (mainThread, 0);
@@ -56,6 +56,8 @@ public sealed class LuaState
         mainThread = new(this);
         environment = new();
         envUpValue = UpValue.Closed(environment);
+        registry[ModuleLibrary.LoadedKeyForRegistry] = new LuaTable(0, 8);
+        registry[ModuleLibrary.PreloadKeyForRegistry] = new LuaTable(0, 8);
     }
 
 

+ 112 - 8
src/Lua/Standard/ModuleLibrary.cs

@@ -5,13 +5,17 @@ namespace Lua.Standard;
 public sealed class ModuleLibrary
 {
     public static readonly ModuleLibrary Instance = new();
+    internal const string LoadedKeyForRegistry = "_LOADED";
+    internal const string PreloadKeyForRegistry = "_PRELOAD";
 
     public ModuleLibrary()
     {
         RequireFunction = new("require", Require);
+        SearchPathFunction = new("searchpath", SearchPath);
     }
 
     public readonly LuaFunction RequireFunction;
+    public readonly LuaFunction SearchPathFunction;
 
     public async ValueTask<int> Require(LuaFunctionExecutionContext context, CancellationToken cancellationToken)
     {
@@ -20,18 +24,118 @@ public sealed class ModuleLibrary
 
         if (!loaded.TryGetValue(arg0, out var loadedTable))
         {
-            LuaClosure closure;
-            {
-                using var module = await context.State.ModuleLoader.LoadAsync(arg0, cancellationToken);
-                closure = module.Type == LuaModuleType.Bytes
-                    ? context.State.Load(module.ReadBytes(), module.Name)
-                    : context.State.Load(module.ReadText(), module.Name);
-            }
-            await context.Access.RunAsync(closure, 0, context.ReturnFrameBase, cancellationToken);
+            var loader = await FindLoader(context.Access, arg0, cancellationToken);
+            await context.Access.RunAsync(loader, 0, context.ReturnFrameBase, cancellationToken);
             loadedTable = context.Thread.Stack.Get(context.ReturnFrameBase);
             loaded[arg0] = loadedTable;
         }
 
         return context.Return(loadedTable);
     }
+
+    internal static async ValueTask<string?> FindFile(LuaThreadAccess access, string name, string pName, string dirSeparator)
+    {
+        var thread = access.Thread;
+        var state = thread.State;
+        var package = state.Environment["package"];
+        var p = await access.GetTable(package, pName);
+        if (!p.TryReadString(out var path))
+        {
+            throw new LuaRuntimeException(thread, ($"package.{pName} must be a string"));
+        }
+
+        return SearchPath(state, name, path, ".", dirSeparator);
+    }
+
+    public ValueTask<int> SearchPath(LuaFunctionExecutionContext context, CancellationToken cancellationToken)
+    {
+        var name = context.GetArgument<string>(0);
+        var path = context.GetArgument<string>(1);
+        var separator = context.GetArgument<string>(2);
+        var dirSeparator = context.GetArgument<string>(3);
+        var fileName = SearchPath(context.State, name, path, separator, dirSeparator);
+        return new(context.Return(fileName ?? LuaValue.Nil));
+    }
+
+    internal static string? SearchPath(LuaState state, string name, string path, string separator, string dirSeparator)
+    {
+        if (separator != "")
+        {
+            name = name.Replace(separator, dirSeparator);
+        }
+
+        var pathSpan = path.AsSpan();
+        var nextIndex = pathSpan.IndexOf(';');
+        if (nextIndex == -1) nextIndex = pathSpan.Length;
+        do
+        {
+            path = pathSpan[..nextIndex].ToString();
+            var fileName = path.Replace("?", name);
+            if (File.Exists(fileName))
+            {
+                return fileName;
+            }
+
+            if (pathSpan.Length <= nextIndex) break;
+            pathSpan = pathSpan[(nextIndex + 1)..];
+            nextIndex = pathSpan.IndexOf(';');
+            if (nextIndex == -1) nextIndex = pathSpan.Length;
+        } while (nextIndex != -1);
+
+        return null;
+    }
+
+    internal static async ValueTask<LuaFunction> FindLoader(LuaThreadAccess access, string name, CancellationToken cancellationToken)
+    {
+        var state = access.State;
+        var package = state.Environment["package"].Read<LuaTable>();
+        var searchers = package["searchers"].Read<LuaTable>();
+        for (int i = 0; i < searchers.GetArraySpan().Length; i++)
+        {
+            var searcher = searchers.GetArraySpan()[i];
+            if (searcher.Type == LuaValueType.Nil) continue;
+            var loader = searcher;
+            var top = access.Stack.Count;
+            access.Stack.Push(loader);
+            access.Stack.Push(name);
+            var resultCount = await access.Call(top, top, cancellationToken);
+            if (0 < resultCount)
+            {
+                var result = access.Stack.Get(top);
+                if (result.Type == LuaValueType.Function)
+                {
+                    access.Stack.SetTop(top);
+                    return result.Read<LuaFunction>();
+                }
+            }
+
+            access.Stack.SetTop(top);
+        }
+
+        throw new LuaRuntimeException(access.Thread, ($"Module '{name}' not found"));
+    }
+
+    public ValueTask<int> SearcherPreload(LuaFunctionExecutionContext context, CancellationToken cancellationToken)
+    {
+        var name = context.GetArgument<string>(0);
+        var preload = context.State.PreloadModules[name];
+        if (preload == LuaValue.Nil)
+        {
+            return new(context.Return());
+        }
+
+        return new(context.Return(preload));
+    }
+
+    public async ValueTask<int> SearcherLua(LuaFunctionExecutionContext context, CancellationToken cancellationToken)
+    {
+        var name = context.GetArgument<string>(0);
+        var fileName = await FindFile(context.Access, name, "path", Path.DirectorySeparatorChar.ToString());
+        if (fileName == null)
+        {
+            return (context.Return(LuaValue.Nil));
+        }
+
+        return context.Return(await context.State.LoadFileAsync(fileName, "bt", null, cancellationToken));
+    }
 }

+ 14 - 5
src/Lua/Standard/OpenLibsExtensions.cs

@@ -47,9 +47,9 @@ public static class OpenLibsExtensions
         }
 
         var registry = state.Registry;
-        registry ["stdin"] = new (new FileHandle(ConsoleHelper.OpenStandardInput()));
-        registry["stdout"] =new (new FileHandle(ConsoleHelper.OpenStandardOutput()));
-        registry["stderr"] = new (new FileHandle(ConsoleHelper.OpenStandardError()));
+        registry["stdin"] = new(new FileHandle(ConsoleHelper.OpenStandardInput()));
+        registry["stdout"] = new(new FileHandle(ConsoleHelper.OpenStandardOutput()));
+        registry["stderr"] = new(new FileHandle(ConsoleHelper.OpenStandardError()));
 
         state.Environment["io"] = io;
         state.LoadedModules["io"] = io;
@@ -74,10 +74,19 @@ public static class OpenLibsExtensions
 
     public static void OpenModuleLibrary(this LuaState state)
     {
-        var package = new LuaTable();
+        var package = new LuaTable(0, 8);
         package["loaded"] = state.LoadedModules;
+        package["preload"] = state.PreloadModules;
+        var moduleLibrary = ModuleLibrary.Instance;
+        var searchers = new LuaTable();
+        searchers[1] = new LuaFunction("preload", moduleLibrary.SearcherPreload);
+        searchers[2] = new LuaFunction("searcher_Lua", moduleLibrary.SearcherLua);
+        package["searchers"] = searchers;
+        package["path"] = "?.lua";
+        package["searchpath"] = moduleLibrary.SearchPathFunction;
+        package["config"] = $"{Path.DirectorySeparatorChar}\n;\n?\n!\n-";
         state.Environment["package"] = package;
-        state.Environment["require"] = ModuleLibrary.Instance.RequireFunction;
+        state.Environment["require"] = moduleLibrary.RequireFunction;
     }
 
     public static void OpenOperatingSystemLibrary(this LuaState state)