Browse Source

added calldownEnforced, EnforcesCalldowns

Darren Ranalli 19 years ago
parent
commit
5e751cb8ad
2 changed files with 219 additions and 96 deletions
  1. 12 3
      direct/src/distributed/DistributedObject.py
  2. 207 93
      direct/src/showbase/PythonUtil.py

+ 12 - 3
direct/src/distributed/DistributedObject.py

@@ -3,6 +3,7 @@
 from pandac.PandaModules import *
 from direct.directnotify.DirectNotifyGlobal import directNotify
 from direct.distributed.DistributedObjectBase import DistributedObjectBase
+from direct.showbase.PythonUtil import EnforcesCalldowns, calldownEnforced
 #from PyDatagram import PyDatagram
 #from PyDatagramIterator import PyDatagramIterator
 
@@ -15,7 +16,7 @@ ESDisabled     = 4  # values here and lower are considered "disabled"
 ESGenerating   = 5  # values here and greater are considered "generated"
 ESGenerated    = 6
 
-class DistributedObject(DistributedObjectBase):
+class DistributedObject(DistributedObjectBase, EnforcesCalldowns):
     """
     The Distributed Object class is the base class for all network based
     (i.e. distributed) objects.  These will usually (always?) have a
@@ -37,6 +38,8 @@ class DistributedObject(DistributedObjectBase):
             self.DistributedObject_initialized = 1
             DistributedObjectBase.__init__(self, cr)
 
+            EnforcesCalldowns.__init__(self)
+
             # Most DistributedObjects are simple and require no real
             # effort to load.  Some, particularly actors, may take
             # some significant time to load; these we can optimize by
@@ -200,6 +203,7 @@ class DistributedObject(DistributedObjectBase):
             messenger.send(self.uniqueName("disable"))
             self.disable()
 
+    @calldownEnforced
     def announceGenerate(self):
         """
         Sends a message to the world after the object has been
@@ -210,6 +214,7 @@ class DistributedObject(DistributedObjectBase):
             self.activeState = ESGenerated
             messenger.send(self.uniqueName("generate"), [self])
 
+    @calldownEnforced
     def disable(self):
         """
         Inheritors should redefine this to take appropriate action on disable
@@ -238,6 +243,7 @@ class DistributedObject(DistributedObjectBase):
         assert self.notify.debugStateCall(self)
         return (self.activeState == ESGenerated)
 
+    @calldownEnforced
     def delete(self):
         """
         Inheritors should redefine this to take appropriate action on delete
@@ -249,7 +255,9 @@ class DistributedObject(DistributedObjectBase):
             self.DistributedObject_deleted = 1
             self.cr = None
             self.dclass = None
+            EnforcesCalldowns.destroy(self)
 
+    @calldownEnforced
     def generate(self):
         """
         Inheritors should redefine this to take appropriate action on generate
@@ -257,11 +265,12 @@ class DistributedObject(DistributedObjectBase):
         assert self.notify.debugStateCall(self)
         self.activeState = ESGenerating
         # this has already been set at this point
-        #self.cr.storeObjectLocation(self.doId, self.parentId, self.zoneId)
-        # HACK: we seem to be calling generate() more than once for objects that multiply-inherit
+        #self.cr.storeObjectLocation(self, self.parentId, self.zoneId)
+        # semi-hack: we seem to be calling generate() more than once for objects that multiply-inherit
         if not hasattr(self, '_autoInterestHandle'):
             self.cr.openAutoInterests(self)
 
+    @calldownEnforced
     def generateInit(self):
         """
         This method is called when the DistributedObject is first introduced

+ 207 - 93
direct/src/showbase/PythonUtil.py

@@ -2305,99 +2305,6 @@ class ArgumentEater:
     def __call__(self, *args, **kwArgs):
         self._func(*args[self._numToEat:], **kwArgs)
 
-class HasCheckpoints:
-    """Derive from this class if you want to ensure that specific methods get called.
-    See checkpoint/etc. decorators below"""
-    def __init__(self):
-        self._checkpoints = {}
-        self._CPinit = True
-    def destroyCP(self):
-        del self._checkpoints
-    def _getCPFuncName(self, nameOrFunc):
-        if type(nameOrFunc) == types.StringType:
-            name = nameOrFunc
-        else:
-            func = nameOrFunc
-            name = '%s.%s' % (func.__module__, func.__name__)
-        return name
-    def _CPdestroyed(self):
-        # if self._CPinit does not exist here, our __init__ was not called
-        if self._CPinit and not hasattr(self, '_checkpoints'):
-            # this can occur if a decorated method is called after this
-            # base class has been destroyed.
-            return True
-        return False
-    def CPreset(self, func):
-        name = self._getCPFuncName(func)
-        if name in self._checkpoints:
-            del self._checkpoints[name]
-        return True
-    def CPvisit(self, func):
-        if self._CPdestroyed():
-            return True
-        name = self._getCPFuncName(func)
-        self._checkpoints.setdefault(name, 0)
-        self._checkpoints[name] += 1
-        return True
-    # check if a particular method was called
-    def CPcheck(self, func):
-        if self._CPdestroyed():
-            return True
-        name = self._getCPFuncName(func)
-        if self._checkpoints.get(name) is None:
-            __builtin__.tree = Functor(ClassTree, self)
-            raise ('\n%s not called for %s.%s\n call tree() to view %s class hierarchy' % (
-                name, self.__module__, self.__class__.__name__, self.__class__.__name__))
-        self.CPreset(name)
-        return True
-
-def checkpoint(f):
-    """
-    Use this decorator to track if a particular method has been called.
-    Class must derive from HasCheckpoints.
-    """
-    def _checkpoint(obj, *args, **kArgs):
-        obj.CPvisit(f)
-        return f(obj, *args, **kArgs)
-    _checkpoint.__doc__ = f.__doc__
-    _checkpoint.__name__ = f.__name__
-    _checkpoint.__module__ = f.__module__
-    return _checkpoint
-
-def calldownEnforced(f):
-    """
-    Use this decorator to ensure that derived classes that override this method
-    call down to the base class method.
-    """
-    # TODO
-    return f
-
-if __debug__:
-    class CheckPointTest(HasCheckpoints):
-        @checkpoint
-        def testFunc(self):
-            pass
-    cpt = CheckPointTest()
-    raised = True
-    try:
-        cpt.CPcheck(CheckPointTest.testFunc)
-        raised = False
-    except:
-        pass
-    if not raised:
-        raise 'CPcheck failed to raise'
-    cpt.testFunc()
-    cpt.CPcheck(CheckPointTest.testFunc)
-    try:
-        cpt.CPcheck('testFunc')
-        raised = False
-    except:
-        pass
-    if not raised:
-        raise 'CPcheck failed to raise'
-    del cpt
-    del CheckPointTest
-
 class ClassTree:
     def __init__(self, instanceOrClass):
         if type(instanceOrClass) in (types.ClassType, types.TypeType):
@@ -2409,6 +2316,13 @@ class ClassTree:
         for base in self._cls.__bases__:
             if base not in (types.ObjectType, types.TypeType):
                 self._bases.append(ClassTree(base))
+    def getAllClasses(self):
+        # returns set of this class and all base classes
+        classes = set()
+        classes.add(self._cls)
+        for base in self._bases:
+            classes.update(base.getAllClasses())
+        return classes
     def _getStr(self, indent=None, clsLeftAtIndent=None):
         # indent is how far to the right to indent (i.e. how many levels
         # deep in the hierarchy from the most-derived)
@@ -2433,6 +2347,16 @@ class ClassTree:
             s += ' +'
         s += self._cls.__name__
         clsLeftAtIndent[indent] -= 1
+        """
+        ### show the module to the right of the class name
+        moduleIndent = 48
+        if len(s) >= moduleIndent:
+            moduleIndent = (len(s) % 4) + 4
+        padding = moduleIndent - len(s)
+        s += padding * ' '
+        s += self._cls.__module__
+        ###
+        """
         if len(self._bases):
             newList = list(clsLeftAtIndent)
             newList.append(len(self._bases))
@@ -2445,6 +2369,196 @@ class ClassTree:
     def __repr__(self):
         return self._getStr()
 
+class EnforcedCalldownException(Exception):
+    def __init__(self, what):
+        Exception.__init__(self, what)
+
+class EnforcesCalldowns:
+    """Derive from this class if you want to ensure that specific methods
+    get called.  See calldownEnforced decorator below"""
+
+    # class-level data for enforcement of base class method call-down
+    #
+    # The problem is that we don't have access to the class in the
+    # decorator, so we need to put the decorated methods in a global
+    # dict. We can then insert a stub method on each class instance for
+    # every method that has enforced base-class methods, and the stub can
+    # watch for each base-class method checkpoint to be passed.
+
+    # since the decorator can't know its own id until after it has been
+    # defined, we map from decorator ID to original func ID
+    _decoId2funcId = {}
+    # as calldownEnforced decorators are created, they add themselves to
+    # this dict. At this point we don't know what class they belong to.
+    _funcId2func = {}
+    # this is here so that we can print nice error messages
+    _funcId2class = {}
+    # as class instances are created, we populate this dictionary of
+    # class to func name to list of func ids. The lists of func ids
+    # include base-class funcs.
+    _class2funcName2funcIds = {}
+
+    # this method will be inserted into instances of classes that need
+    # to enforce base-class method calls, as the most-derived implementation
+    # of the method
+    @staticmethod
+    def _enforceCalldowns(oldMethod, name, obj, *args, **kArgs):
+        name2funcIds = EnforcesCalldowns._class2funcName2funcIds[obj.__class__]
+        funcIds = name2funcIds.get(name)
+
+        # prepare for the method call
+        for funcId in funcIds:
+            obj._EClatch(funcId)
+
+        # call the actual method that we're stubbing
+        result = oldMethod(*args, **kArgs)
+
+        # check on the results
+        for funcId in funcIds:
+            obj._ECcheck(funcId)
+
+        return result
+
+    def __init__(self):
+        if not __debug__:
+            return
+
+        # this map tracks how many times each func has been called
+        self._funcId2calls = {}
+        # this map tracks the 'latch' values for each func; if the call count
+        # for a func is greater than the latch, then the func has been called.
+        self._funcId2latch = {}
+
+        if self.__class__ not in EnforcesCalldowns._class2funcName2funcIds:
+            # prepare stubs to enforce method call-downs
+            EnforcesCalldowns._class2funcName2funcIds.setdefault(self.__class__, {})
+            # look through all of our base classes and find matches
+            classes = ClassTree(self).getAllClasses()
+            # collect IDs of all the enforced methods
+            funcId2func = {}
+            for cls in classes:
+                for name, item in cls.__dict__.items():
+                    if id(item) in EnforcesCalldowns._decoId2funcId:
+                        funcId = EnforcesCalldowns._decoId2funcId[id(item)]
+                        funcId2func[funcId] = item
+                        EnforcesCalldowns._funcId2class[funcId] = cls
+            # add these funcs to the list for our class
+            funcName2funcIds = EnforcesCalldowns._class2funcName2funcIds[self.__class__]
+            for funcId, func in funcId2func.items():
+                funcName2funcIds.setdefault(func.__name__, [])
+                funcName2funcIds[func.__name__].append(funcId)
+
+        # now run through all the enforced funcs for this class and insert
+        # stub methods to do the enforcement
+        funcName2funcIds = EnforcesCalldowns._class2funcName2funcIds[self.__class__]
+        self._obscuredMethodNames = set()
+        for name in funcName2funcIds:
+            oldMethod = getattr(self, name)
+            self._obscuredMethodNames.add(name)
+            setattr(self, name, new.instancemethod(
+                Functor(EnforcesCalldowns._enforceCalldowns, oldMethod, name),
+                self, self.__class__))
+            
+    def destroy(self):
+        # this must be called on destruction to prevent memory leaks
+        #import pdb;pdb.set_trace()
+        for name in self._obscuredMethodNames:
+            delattr(self, name)
+        del self._obscuredMethodNames
+        # this opens up more cans of worms. Let's keep it closed for the moment
+        #del self._funcId2calls
+        #del self._funcId2latch
+
+    def skipCalldown(self, method):
+        if not __debug__:
+            return
+        # Call this function if you really don't want to call down to an
+        # enforced base-class method. This should hardly ever be used.
+        funcName2funcIds = EnforcesCalldowns._class2funcName2funcIds[self.__class__]
+        funcIds = funcName2funcIds[method.__name__]
+        for funcId in funcIds:
+            self._ECvisit(funcId)
+
+    def _EClatch(self, funcId):
+        self._funcId2calls.setdefault(funcId, 0)
+        self._funcId2latch[funcId] = self._funcId2calls[funcId]
+    def _ECvisit(self, funcId):
+        self._funcId2calls.setdefault(funcId, 0)
+        self._funcId2calls[funcId] += 1
+    def _ECcheck(self, funcId):
+        if self._funcId2latch[funcId] == self._funcId2calls[funcId]:
+            func = EnforcesCalldowns._funcId2func[funcId]
+            raise EnforcedCalldownException(
+                '%s.%s did not call down to %s.%s\n%s' % (
+                self.__class__.__module__, self.__class__.__name__,
+                EnforcesCalldowns._funcId2class[funcId].__name__,
+                func.__name__,
+                ClassTree(self)))
+
+def calldownEnforced(f):
+    """
+    Use this decorator to ensure that derived classes that override this method
+    call down to the base class method.
+    """
+    if not __debug__:
+        return f
+    def calldownEnforcedImpl(obj, *args, **kArgs):
+        # track the fact that this func has been called
+        obj._ECvisit(id(f))
+        f(obj, *args, **kArgs)
+    calldownEnforcedImpl.__doc__ = f.__doc__
+    calldownEnforcedImpl.__name__ = f.__name__
+    calldownEnforcedImpl.__module__ = f.__module__
+    EnforcesCalldowns._decoId2funcId[id(calldownEnforcedImpl)] = id(f)
+    EnforcesCalldowns._funcId2func[id(f)] = calldownEnforcedImpl
+    return calldownEnforcedImpl
+
+if __debug__:
+    class CalldownEnforceTest(EnforcesCalldowns):
+        @calldownEnforced
+        def testFunc(self):
+            pass
+    class CalldownEnforceTestSubclass(CalldownEnforceTest):
+        def testFunc(self):
+            CalldownEnforceTest.testFunc(self)
+    class CalldownEnforceTestSubclassFail(CalldownEnforceTest):
+        def testFunc(self):
+            pass
+    class CalldownEnforceTestSubclassSkip(CalldownEnforceTest):
+        def testFunc(self):
+            self.skipCalldown(CalldownEnforceTest.testFunc)
+    cets = CalldownEnforceTestSubclass()
+    cetsf = CalldownEnforceTestSubclassFail()
+    cetss = CalldownEnforceTestSubclassSkip()
+    raised = False
+    try:
+        cets.testFunc()
+    except EnforcedCalldownException, e:
+        raised = True
+    if raised:
+        raise "calldownEnforced raised when it shouldn't"
+    raised = False
+    try:
+        cetsf.testFunc()
+    except EnforcedCalldownException, e:
+        raised = True
+    if not raised:
+        raise 'calldownEnforced failed to raise'
+    raised = False
+    try:
+        cetss.testFunc()
+    except EnforcedCalldownException, e:
+        raised = True
+    if raised:
+        raise "calldownEnforced.skipCalldown raised when it shouldn't"
+    del cetss
+    del cetsf
+    del cets
+    del CalldownEnforceTestSubclassSkip
+    del CalldownEnforceTestSubclassFail
+    del CalldownEnforceTestSubclass
+    del CalldownEnforceTest
+
 import __builtin__
 __builtin__.Functor = Functor
 __builtin__.Stack = Stack