Browse Source

added checkpoint, ClassTree

Darren Ranalli 19 years ago
parent
commit
de70d67d0f
1 changed files with 140 additions and 0 deletions
  1. 140 0
      direct/src/showbase/PythonUtil.py

+ 140 - 0
direct/src/showbase/PythonUtil.py

@@ -2305,6 +2305,146 @@ class ArgumentEater:
     def __call__(self, *args, **kwArgs):
         self._func(*args[self._numToEat:], **kwArgs)
 
+class HasCheckpoints:
+    """Derive from this class if you want to ensure that specific methods get called.
+    See checkpoint/etc. decorators below"""
+    def __init__(self):
+        self._checkpoints = {}
+        self._CPinit = True
+    def destroyCP(self):
+        del self._checkpoints
+    def _getCPFuncName(self, nameOrFunc):
+        if type(nameOrFunc) == types.StringType:
+            name = nameOrFunc
+        else:
+            func = nameOrFunc
+            name = '%s.%s' % (func.__module__, func.__name__)
+        return name
+    def _CPdestroyed(self):
+        # if self._CPinit does not exist here, our __init__ was not called
+        if self._CPinit and not hasattr(self, '_checkpoints'):
+            # this can occur if a decorated method is called after this
+            # base class has been destroyed.
+            return True
+        return False
+    def CPreset(self, func):
+        name = self._getCPFuncName(func)
+        if name in self._checkpoints:
+            del self._checkpoints[name]
+        return True
+    def CPvisit(self, func):
+        if self._CPdestroyed():
+            return True
+        name = self._getCPFuncName(func)
+        self._checkpoints.setdefault(name, 0)
+        self._checkpoints[name] += 1
+        return True
+    # check if a particular method was called
+    def CPcheck(self, func):
+        if self._CPdestroyed():
+            return True
+        name = self._getCPFuncName(func)
+        if self._checkpoints.get(name) is None:
+            __builtin__.tree = Functor(ClassTree, self)
+            raise ('\n%s not called for %s.%s\n call tree() to view %s class hierarchy' % (
+                name, self.__module__, self.__class__.__name__, self.__class__.__name__))
+        self.CPreset(name)
+        return True
+
+def checkpoint(f):
+    """
+    Use this decorator to track if a particular method has been called.
+    Class must derive from HasCheckpoints.
+    """
+    def _checkpoint(obj, *args, **kArgs):
+        obj.CPvisit(f)
+        return f(obj, *args, **kArgs)
+    _checkpoint.__doc__ = f.__doc__
+    _checkpoint.__name__ = f.__name__
+    _checkpoint.__module__ = f.__module__
+    return _checkpoint
+
+def calldownEnforced(f):
+    """
+    Use this decorator to ensure that derived classes that override this method
+    call down to the base class method.
+    """
+    # TODO
+    return f
+
+if __debug__:
+    class CheckPointTest(HasCheckpoints):
+        @checkpoint
+        def testFunc(self):
+            pass
+    cpt = CheckPointTest()
+    raised = True
+    try:
+        cpt.CPcheck(CheckPointTest.testFunc)
+        raised = False
+    except:
+        pass
+    if not raised:
+        raise 'CPcheck failed to raise'
+    cpt.testFunc()
+    cpt.CPcheck(CheckPointTest.testFunc)
+    try:
+        cpt.CPcheck('testFunc')
+        raised = False
+    except:
+        pass
+    if not raised:
+        raise 'CPcheck failed to raise'
+    del cpt
+    del CheckPointTest
+
+class ClassTree:
+    def __init__(self, instanceOrClass):
+        if type(instanceOrClass) in (types.ClassType, types.TypeType):
+            cls = instanceOrClass
+        else:
+            cls = instanceOrClass.__class__
+        self._cls = cls
+        self._bases = []
+        for base in self._cls.__bases__:
+            if base not in (types.ObjectType, types.TypeType):
+                self._bases.append(ClassTree(base))
+    def _getStr(self, indent=None, clsLeftAtIndent=None):
+        # indent is how far to the right to indent (i.e. how many levels
+        # deep in the hierarchy from the most-derived)
+        #
+        # clsLeftAtIndent is an array of # of classes left to be
+        # printed at each level of the hierarchy; most-derived is
+        # at index 0
+        if indent is None:
+            indent = 0
+            clsLeftAtIndent = [1]
+        s = ''
+        if (indent > 1):
+            for i in range(1, indent):
+                # if we have not printed all base classes at
+                # this indent level, keep printing the vertical
+                # column
+                if clsLeftAtIndent[i] > 0:
+                    s += ' |'
+                else:
+                    s += '  '
+        if (indent > 0):
+            s += ' +'
+        s += self._cls.__name__
+        clsLeftAtIndent[indent] -= 1
+        if len(self._bases):
+            newList = list(clsLeftAtIndent)
+            newList.append(len(self._bases))
+            bases = self._bases
+            # print classes with fewer bases first
+            bases.sort(lambda x,y: len(x._bases)-len(y._bases))
+            for base in bases:
+                s += '\n%s' % base._getStr(indent+1, newList)
+        return s
+    def __repr__(self):
+        return self._getStr()
+
 import __builtin__
 __builtin__.Functor = Functor
 __builtin__.Stack = Stack