瀏覽代碼

';;' in a path is replaced by default path + '!' (in Windows) is
replaced by executable's directory

Roberto Ierusalimschy 20 年之前
父節點
當前提交
e4324f54b9
共有 1 個文件被更改,包括 46 次插入18 次删除
  1. 46 18
      loadlib.c

+ 46 - 18
loadlib.c

@@ -1,5 +1,5 @@
 /*
-** $Id: loadlib.c,v 1.30 2005/06/27 17:24:40 roberto Exp roberto $
+** $Id: loadlib.c,v 1.31 2005/07/05 19:29:03 roberto Exp roberto $
 ** Dynamic library loader for Lua
 ** See Copyright Notice in lua.h
 **
@@ -38,6 +38,9 @@
 #define LIB_FAIL	"open"
 
 
+#define setprogdir(L)		((void)0)
+
+
 static void ll_unloadlib (void *lib);
 static void *ll_load (lua_State *L, const char *path);
 static lua_CFunction ll_sym (lua_State *L, void *lib, const char *sym);
@@ -86,6 +89,21 @@ static lua_CFunction ll_sym (lua_State *L, void *lib, const char *sym) {
 */
 
 #include <windows.h>
+#include "Shlwapi.h"
+
+
+#undef setprogdir
+
+void setprogdir (lua_State *L) {
+  char buff[MAX_PATH + 1];
+  DWORD nsize = sizeof(buff)/sizeof(char);
+  DWORD n = GetModuleFileName(NULL, buff, nsize);
+  if (n == 0 || n == nsize)
+    luaL_error(L, "unable to get ModuleFileName");
+  PathRemoveFileSpec(buff);
+  luaL_gsub(L, lua_tostring(L, -1), LUA_EXECDIR, buff);
+  lua_remove(L, -2);  /* remove original string */
+}
 
 
 static void pusherror (lua_State *L) {
@@ -383,14 +401,18 @@ static void require_aux (lua_State *L, const char *name) {
       luaL_error(L, "package " LUA_QS " not found", name);
     lua_pushstring(L, name);
     lua_call(L, 1, 1);  /* call it */
-    if (lua_isnil(L, -1)) lua_pop(L, 1);
+    if (lua_isnil(L, -1)) lua_pop(L, 1);  /* did not found module */
     else break;  /* module loaded successfully */
   }
-  /* mark module as loaded */
   lua_pushboolean(L, 1);
   lua_setfield(L, loadedtable, name);  /* _LOADED[name] = true */
   lua_pushstring(L, name);  /* pass name as argument to module */
-  lua_call(L, 1, 1);  /* run loaded module */
+  if (lua_pcall(L, 1, 1, 0) != 0) {  /* run loaded module */
+    lua_pushnil(L);  /* in case of errors... */
+    lua_setfield(L, loadedtable, name);  /* ...clear _LOADED[name] */
+    luaL_error(L, "error loading package " LUA_QS " (%s)",
+                  name, lua_tostring(L, -1));  /* propagate error */
+  }
   if (!lua_isnil(L, -1))  /* non-nil return? */
     lua_setfield(L, loadedtable, name);  /* _LOADED[name] = returned value */
   lua_getfield(L, loadedtable, name);  /* return _LOADED[name] */
@@ -468,6 +490,22 @@ static int ll_module (lua_State *L) {
 /* }====================================================== */
 
 
+static void setpath (lua_State *L, const char *fname, const char *envname,
+                                   const char *def) {
+  const char *path = getenv(envname);
+  if (path == NULL) lua_pushstring(L, def);
+  else {
+    /* replace ";;" by default path */
+    path = luaL_gsub(L, path, LUA_PATHSEP LUA_PATHSEP,
+                              LUA_PATHSEP"\1"LUA_PATHSEP);
+    luaL_gsub(L, path, "\1", def);
+    lua_remove(L, -2);
+  }
+  setprogdir(L);
+  lua_setfield(L, -2, fname);
+}
+
+
 static const luaL_reg ll_funcs[] = {
   {"require", ll_require},
   {"module", ll_module},
@@ -476,11 +514,10 @@ static const luaL_reg ll_funcs[] = {
 
 
 static const lua_CFunction loaders[] =
-  {loader_preload, loader_C, loader_Lua, NULL};
+  {loader_preload, loader_Lua, loader_C, NULL};
 
 
 LUALIB_API int luaopen_loadlib (lua_State *L) {
-  const char *path;
   int i;
   /* create new type _LOADLIB */
   luaL_newmetatable(L, "_LOADLIB");
@@ -501,18 +538,9 @@ LUALIB_API int luaopen_loadlib (lua_State *L) {
     lua_pushcfunction(L, loaders[i]);
     lua_rawseti(L, -2, i+1);
   }
-  /* put it in field `loaders' */
-  lua_setfield(L, -2, "loaders");
-  /* set field `path' */
-  path = getenv(LUA_PATH);
-  if (path == NULL) path = LUA_PATH_DEFAULT;
-  lua_pushstring(L, path);
-  lua_setfield(L, -2, "path");
-  /* set field `cpath' */
-  path = getenv(LUA_CPATH);
-  if (path == NULL) path = LUA_CPATH_DEFAULT;
-  lua_pushstring(L, path);
-  lua_setfield(L, -2, "cpath");
+  lua_setfield(L, -2, "loaders");  /* put it in field `loaders' */
+  setpath(L, "path", LUA_PATH, LUA_PATH_DEFAULT);  /* set field `path' */
+  setpath(L, "cpath", LUA_CPATH, LUA_CPATH_DEFAULT); /* set field `cpath' */
   /* set field `loaded' */
   lua_getfield(L, LUA_REGISTRYINDEX, "_LOADED");
   lua_setfield(L, -2, "loaded");