Browse Source

Added expand-imports functionality

Josh Yelon 19 years ago
parent
commit
b261d92f98
1 changed files with 144 additions and 22 deletions
  1. 144 22
      direct/src/directscripts/gendocs.py

+ 144 - 22
direct/src/directscripts/gendocs.py

@@ -58,6 +58,8 @@ import os, sys, parser, symbol, token, types, re
 
 SECHEADER = re.compile("^[A-Z][a-z]+\s*:")
 JUNKHEADER = re.compile("^((Function)|(Access))\s*:")
+IMPORTSTAR = re.compile("^from\s+([a-zA-Z0-9_.]+)\s+import\s+[*]\s*$")
+IDENTIFIER = re.compile("[a-zA-Z0-9_]+")
 
 def readFile(fn):
     try:
@@ -76,6 +78,16 @@ def writeFile(wfile, data):
     except:
         sys.exit("Cannot write "+wfile)
 
+def writeFileLines(wfile, lines):
+    try:
+        dsthandle = open(wfile, "wb")
+        for x in lines:
+            dsthandle.write(x)
+            dsthandle.write("\n")
+        dsthandle.close()
+    except:
+        sys.exit("Cannot write "+wfile)
+
 def findFiles(dirlist, ext, ign, list):
     if isinstance(dirlist, types.StringTypes):
         dirlist = [dirlist]
@@ -89,6 +101,12 @@ def findFiles(dirlist, ext, ign, list):
                 elif (os.path.isdir(full)):
                     findFiles(full, ext, ign, list)
 
+def pathToModule(result):
+    if (result[-3:]==".py"): result=result[:-3]
+    result = result.replace("/src/","/")
+    result = result.replace("/",".")
+    return result
+
 def textToHTML(comment, sep, delsection=None):
     sections = [""]
     included = {}
@@ -367,7 +385,31 @@ DERIVATION_PATTERN = (
                 (token.NAME, ['classname'])
    )))))))))))))
 
-
+ASSIGNMENT_STMT_PATTERN = (
+    symbol.stmt,
+    (symbol.simple_stmt,
+     (symbol.small_stmt,
+      (symbol.expr_stmt,
+       (symbol.testlist,
+        (symbol.test,
+         (symbol.and_test,
+          (symbol.not_test,
+           (symbol.comparison,
+            (symbol.expr,
+             (symbol.xor_expr,
+              (symbol.and_expr,
+               (symbol.shift_expr,
+                (symbol.arith_expr,
+                 (symbol.term,
+                  (symbol.factor,
+                   (symbol.power,
+                    (symbol.atom,
+                     (token.NAME, ['varname']),
+       )))))))))))))),
+       (token.EQUAL, '='),
+       (symbol.testlist, ['rhs']))),
+     (token.NEWLINE, ''),
+   ))
 
 class ParseTreeInfo:
     docstring = ''
@@ -382,6 +424,7 @@ class ParseTreeInfo:
         self.file = file
         self.class_info = {}
         self.function_info = {}
+        self.assign_info = {}
         self.derivs = {}
         if isinstance(tree, types.StringType):
             try:
@@ -440,6 +483,9 @@ class ParseTreeInfo:
             self.docstring = eval(vars['docstring'])
         # discover inner definitions
         for node in tree[1:]:
+            found, vars = self.match(ASSIGNMENT_STMT_PATTERN, node)
+            if found:
+                self.assign_info[vars['varname']] = 1
             found, vars = self.match(COMPOUND_STMT_PATTERN, node)
             if found:
                 cstmt = vars['compound']
@@ -486,6 +532,9 @@ class CodeDatabase:
         self.types = {}
         self.funcs = {}
         self.goodtypes = {}
+        self.funcExports = {}
+        self.typeExports = {}
+        self.varExports = {}
         self.globalfn = []
         print "Reading C++ source files"
         for cxx in cxxlist:
@@ -496,26 +545,33 @@ class CodeDatabase:
                     self.types[type.scopedname] = type
                 if (type.flags & 8192) and (type.atomictype == 0) and (type.scopedname.count(" ")==0) and (type.scopedname.count(":")==0):
                     self.goodtypes[type.scopedname] = type
+                    self.typeExports.setdefault("pandac.PandaModules", []).append(type.scopedname)
             for func in idb.functions.values():
                 type = idb.types.get(func.classindex)
                 func.pyname = convertToPythonFn(func.componentname)
                 if (type == None):
                     self.funcs["GLOBAL."+func.pyname] = func
                     self.globalfn.append("GLOBAL."+func.pyname)
+                    self.funcExports.setdefault("pandac.PandaModules", []).append(func.pyname)
                 else:
                     self.funcs[type.scopedname+"."+func.pyname] = func
         print "Reading Python sources files"
         for py in pylist:
             pyinf = ParseTreeInfo(readFile(py), py, py)
+            mod = pathToModule(py)
             for type in pyinf.class_info.keys():
                 typinf = pyinf.class_info[type]
                 self.types[type] = typinf
                 self.goodtypes[type] = typinf
+                self.typeExports.setdefault(mod, []).append(type)
                 for func in typinf.function_info.keys():
                     self.funcs[type+"."+func] = typinf.function_info[func]
             for func in pyinf.function_info.keys():
                 self.funcs["GLOBAL."+func] = pyinf.function_info[func]
                 self.globalfn.append("GLOBAL."+func)
+                self.funcExports.setdefault(mod, []).append(func)
+            for var in pyinf.assign_info.keys():
+                self.varExports.setdefault(mod, []).append(var)
 
     def getClassList(self):
         return self.goodtypes.keys()
@@ -576,11 +632,7 @@ class CodeDatabase:
         if (isinstance(type, InterrogateType)):
             return "pandac.PandaModules"
         else:
-            result = type.file
-            if (result[-3:]==".py"): result=result[:-3]
-            result = result.replace("/src/","/")
-            result = result.replace("/",".")
-            return result
+            return pathToModule(type.file)
 
     def getClassMethods(self, cn):
         type = self.types.get(cn)
@@ -604,6 +656,13 @@ class CodeDatabase:
         else:
             return fn
 
+    def getFunctionImport(self, fn):
+        func = self.funcs.get(fn)
+        if (isinstance(func, InterrogateFunction)):
+            return "pandac.PandaModules"
+        else:
+            return pathToModule(func.file)
+
     def getFunctionPrototype(self, fn):
         func = self.funcs.get(fn)
         if (isinstance(func, InterrogateFunction)):
@@ -624,6 +683,15 @@ class CodeDatabase:
             return textToHTML(func.docstring, "#")
         return fn
 
+    def getFuncExports(self, mod):
+        return self.funcExports.get(mod, [])
+
+    def getTypeExports(self, mod):
+        return self.typeExports.get(mod, [])
+
+    def getVarExports(self, mod):
+        return self.varExports.get(mod, [])
+
 ########################################################################
 #
 # The "Class Rename Dictionary" - Yech.
@@ -672,6 +740,23 @@ CLASS_RENAME_DICT = {
 #
 ########################################################################
 
+def makeCodeDatabase(indirlist, directdirlist):
+    if isinstance(directdirlist, types.StringTypes):
+        directdirlist = [directdirlist]
+    ignore = {}
+    ignore["__init__.py"] = 1
+    for directdir in directdirlist:
+        ignore[directdir + "/src/directscripts"] = 1
+        ignore[directdir + "/src/extensions"] = 1
+        ignore[directdir + "/src/extensions_native"] = 1
+        ignore[directdir + "/src/ffi"] = 1
+        ignore[directdir + "/built"] = 1
+    cxxfiles = []
+    pyfiles = []
+    findFiles(indirlist,     ".in", ignore, cxxfiles)
+    findFiles(directdirlist, ".py", ignore, pyfiles)
+    return CodeDatabase(cxxfiles, pyfiles)
+
 def generateFunctionDocs(code, method):
     name = code.getFunctionName(method)
     proto = code.getFunctionPrototype(method)
@@ -704,23 +789,8 @@ def generateLinkTable(link, text, cols, urlprefix, urlsuffix):
     result = result + "</table>\n"
     return result
 
-
 def generate(pversion, indirlist, directdirlist, docdir, header, footer, urlprefix, urlsuffix):
-    if isinstance(directdirlist, types.StringTypes):
-        directdirlist = [directdirlist]
-    ignore = {}
-    ignore["__init__.py"] = 1
-    for directdir in directdirlist:
-        ignore[directdir + "/src/directscripts"] = 1
-        ignore[directdir + "/src/extensions"] = 1
-        ignore[directdir + "/src/extensions_native"] = 1
-        ignore[directdir + "/src/ffi"] = 1
-        ignore[directdir + "/built"] = 1
-    cxxfiles = []
-    pyfiles = []
-    findFiles(indirlist,     ".in", ignore, cxxfiles)
-    findFiles(directdirlist, ".py", ignore, pyfiles)
-    code = CodeDatabase(cxxfiles, pyfiles)
+    code = makeCodeDatabase(indirlist, directdirlist)
     classes = code.getClassList()[:]
     classes.sort(None, str.lower)
     xclasses = classes[:]
@@ -826,3 +896,55 @@ def generate(pversion, indirlist, directdirlist, docdir, header, footer, urlpref
     index = index + "<li>" + linkTo(urlprefix+"methods"+urlsuffix, "List of all Methods (very long)") + "\n"
     index = index + "</ul>\n"
     writeFile(docdir + "/index.html", index)
+
+
+########################################################################
+#
+# IMPORT repair
+#
+########################################################################
+
+def expandImports(indirlist, directdirlist, fixdirlist):
+    code = makeCodeDatabase(indirlist, directdirlist)
+    fixfiles = []
+    findFiles(fixdirlist, ".py", {}, fixfiles)
+    for fixfile in fixfiles:
+        if (os.path.isfile(fixfile+".orig")):
+            text = readFile(fixfile+".orig")
+        else:
+            text = readFile(fixfile)
+            writeFile(fixfile+".orig", text)
+        text = text.replace("\r","")
+        lines = text.split("\n")
+        used = {}
+        for id in IDENTIFIER.findall(text):
+            used[id] = 1
+        result = []
+        for line in lines:
+            mat = IMPORTSTAR.match(line)
+            if (mat):
+                module = mat.group(1)
+                if (fixfile.count("/")!=0) and (module.count(".")==0):
+                    modfile = os.path.dirname(fixfile)+"/"+module+".py"
+                    if (os.path.isfile(modfile)):
+                        module = pathToModule(modfile)
+                typeExports = code.getTypeExports(module)
+                funcExports = code.getFuncExports(module)
+                varExports = code.getVarExports(module)
+                if (len(typeExports)+len(funcExports)+len(varExports)==0):
+                    result.append(line)
+                else:
+                    print "modifying "+fixfile
+                    for x in funcExports:
+                        fn = code.getFunctionName(x)
+                        if (used.has_key(fn)):
+                            result.append("from "+module+" import "+fn)
+                    for x in typeExports:
+                        if (used.has_key(x)):
+                            result.append("from "+module+" import "+x)
+                    for x in varExports:
+                        if (used.has_key(x)):
+                            result.append("from "+module+" import "+x)
+            else:
+                result.append(line)
+        writeFileLines(fixfile, result)