Browse Source

added calldownEnforced, EnforcesCalldowns

Darren Ranalli 19 years ago
parent
commit
5e751cb8ad
2 changed files with 219 additions and 96 deletions
  1. 12 3
      direct/src/distributed/DistributedObject.py
  2. 207 93
      direct/src/showbase/PythonUtil.py

+ 12 - 3
direct/src/distributed/DistributedObject.py

@@ -3,6 +3,7 @@
 from pandac.PandaModules import *
 from pandac.PandaModules import *
 from direct.directnotify.DirectNotifyGlobal import directNotify
 from direct.directnotify.DirectNotifyGlobal import directNotify
 from direct.distributed.DistributedObjectBase import DistributedObjectBase
 from direct.distributed.DistributedObjectBase import DistributedObjectBase
+from direct.showbase.PythonUtil import EnforcesCalldowns, calldownEnforced
 #from PyDatagram import PyDatagram
 #from PyDatagram import PyDatagram
 #from PyDatagramIterator import PyDatagramIterator
 #from PyDatagramIterator import PyDatagramIterator
 
 
@@ -15,7 +16,7 @@ ESDisabled     = 4  # values here and lower are considered "disabled"
 ESGenerating   = 5  # values here and greater are considered "generated"
 ESGenerating   = 5  # values here and greater are considered "generated"
 ESGenerated    = 6
 ESGenerated    = 6
 
 
-class DistributedObject(DistributedObjectBase):
+class DistributedObject(DistributedObjectBase, EnforcesCalldowns):
     """
     """
     The Distributed Object class is the base class for all network based
     The Distributed Object class is the base class for all network based
     (i.e. distributed) objects.  These will usually (always?) have a
     (i.e. distributed) objects.  These will usually (always?) have a
@@ -37,6 +38,8 @@ class DistributedObject(DistributedObjectBase):
             self.DistributedObject_initialized = 1
             self.DistributedObject_initialized = 1
             DistributedObjectBase.__init__(self, cr)
             DistributedObjectBase.__init__(self, cr)
 
 
+            EnforcesCalldowns.__init__(self)
+
             # Most DistributedObjects are simple and require no real
             # Most DistributedObjects are simple and require no real
             # effort to load.  Some, particularly actors, may take
             # effort to load.  Some, particularly actors, may take
             # some significant time to load; these we can optimize by
             # some significant time to load; these we can optimize by
@@ -200,6 +203,7 @@ class DistributedObject(DistributedObjectBase):
             messenger.send(self.uniqueName("disable"))
             messenger.send(self.uniqueName("disable"))
             self.disable()
             self.disable()
 
 
+    @calldownEnforced
     def announceGenerate(self):
     def announceGenerate(self):
         """
         """
         Sends a message to the world after the object has been
         Sends a message to the world after the object has been
@@ -210,6 +214,7 @@ class DistributedObject(DistributedObjectBase):
             self.activeState = ESGenerated
             self.activeState = ESGenerated
             messenger.send(self.uniqueName("generate"), [self])
             messenger.send(self.uniqueName("generate"), [self])
 
 
+    @calldownEnforced
     def disable(self):
     def disable(self):
         """
         """
         Inheritors should redefine this to take appropriate action on disable
         Inheritors should redefine this to take appropriate action on disable
@@ -238,6 +243,7 @@ class DistributedObject(DistributedObjectBase):
         assert self.notify.debugStateCall(self)
         assert self.notify.debugStateCall(self)
         return (self.activeState == ESGenerated)
         return (self.activeState == ESGenerated)
 
 
+    @calldownEnforced
     def delete(self):
     def delete(self):
         """
         """
         Inheritors should redefine this to take appropriate action on delete
         Inheritors should redefine this to take appropriate action on delete
@@ -249,7 +255,9 @@ class DistributedObject(DistributedObjectBase):
             self.DistributedObject_deleted = 1
             self.DistributedObject_deleted = 1
             self.cr = None
             self.cr = None
             self.dclass = None
             self.dclass = None
+            EnforcesCalldowns.destroy(self)
 
 
+    @calldownEnforced
     def generate(self):
     def generate(self):
         """
         """
         Inheritors should redefine this to take appropriate action on generate
         Inheritors should redefine this to take appropriate action on generate
@@ -257,11 +265,12 @@ class DistributedObject(DistributedObjectBase):
         assert self.notify.debugStateCall(self)
         assert self.notify.debugStateCall(self)
         self.activeState = ESGenerating
         self.activeState = ESGenerating
         # this has already been set at this point
         # this has already been set at this point
-        #self.cr.storeObjectLocation(self.doId, self.parentId, self.zoneId)
-        # HACK: we seem to be calling generate() more than once for objects that multiply-inherit
+        #self.cr.storeObjectLocation(self, self.parentId, self.zoneId)
+        # semi-hack: we seem to be calling generate() more than once for objects that multiply-inherit
         if not hasattr(self, '_autoInterestHandle'):
         if not hasattr(self, '_autoInterestHandle'):
             self.cr.openAutoInterests(self)
             self.cr.openAutoInterests(self)
 
 
+    @calldownEnforced
     def generateInit(self):
     def generateInit(self):
         """
         """
         This method is called when the DistributedObject is first introduced
         This method is called when the DistributedObject is first introduced

+ 207 - 93
direct/src/showbase/PythonUtil.py

@@ -2305,99 +2305,6 @@ class ArgumentEater:
     def __call__(self, *args, **kwArgs):
     def __call__(self, *args, **kwArgs):
         self._func(*args[self._numToEat:], **kwArgs)
         self._func(*args[self._numToEat:], **kwArgs)
 
 
-class HasCheckpoints:
-    """Derive from this class if you want to ensure that specific methods get called.
-    See checkpoint/etc. decorators below"""
-    def __init__(self):
-        self._checkpoints = {}
-        self._CPinit = True
-    def destroyCP(self):
-        del self._checkpoints
-    def _getCPFuncName(self, nameOrFunc):
-        if type(nameOrFunc) == types.StringType:
-            name = nameOrFunc
-        else:
-            func = nameOrFunc
-            name = '%s.%s' % (func.__module__, func.__name__)
-        return name
-    def _CPdestroyed(self):
-        # if self._CPinit does not exist here, our __init__ was not called
-        if self._CPinit and not hasattr(self, '_checkpoints'):
-            # this can occur if a decorated method is called after this
-            # base class has been destroyed.
-            return True
-        return False
-    def CPreset(self, func):
-        name = self._getCPFuncName(func)
-        if name in self._checkpoints:
-            del self._checkpoints[name]
-        return True
-    def CPvisit(self, func):
-        if self._CPdestroyed():
-            return True
-        name = self._getCPFuncName(func)
-        self._checkpoints.setdefault(name, 0)
-        self._checkpoints[name] += 1
-        return True
-    # check if a particular method was called
-    def CPcheck(self, func):
-        if self._CPdestroyed():
-            return True
-        name = self._getCPFuncName(func)
-        if self._checkpoints.get(name) is None:
-            __builtin__.tree = Functor(ClassTree, self)
-            raise ('\n%s not called for %s.%s\n call tree() to view %s class hierarchy' % (
-                name, self.__module__, self.__class__.__name__, self.__class__.__name__))
-        self.CPreset(name)
-        return True
-
-def checkpoint(f):
-    """
-    Use this decorator to track if a particular method has been called.
-    Class must derive from HasCheckpoints.
-    """
-    def _checkpoint(obj, *args, **kArgs):
-        obj.CPvisit(f)
-        return f(obj, *args, **kArgs)
-    _checkpoint.__doc__ = f.__doc__
-    _checkpoint.__name__ = f.__name__
-    _checkpoint.__module__ = f.__module__
-    return _checkpoint
-
-def calldownEnforced(f):
-    """
-    Use this decorator to ensure that derived classes that override this method
-    call down to the base class method.
-    """
-    # TODO
-    return f
-
-if __debug__:
-    class CheckPointTest(HasCheckpoints):
-        @checkpoint
-        def testFunc(self):
-            pass
-    cpt = CheckPointTest()
-    raised = True
-    try:
-        cpt.CPcheck(CheckPointTest.testFunc)
-        raised = False
-    except:
-        pass
-    if not raised:
-        raise 'CPcheck failed to raise'
-    cpt.testFunc()
-    cpt.CPcheck(CheckPointTest.testFunc)
-    try:
-        cpt.CPcheck('testFunc')
-        raised = False
-    except:
-        pass
-    if not raised:
-        raise 'CPcheck failed to raise'
-    del cpt
-    del CheckPointTest
-
 class ClassTree:
 class ClassTree:
     def __init__(self, instanceOrClass):
     def __init__(self, instanceOrClass):
         if type(instanceOrClass) in (types.ClassType, types.TypeType):
         if type(instanceOrClass) in (types.ClassType, types.TypeType):
@@ -2409,6 +2316,13 @@ class ClassTree:
         for base in self._cls.__bases__:
         for base in self._cls.__bases__:
             if base not in (types.ObjectType, types.TypeType):
             if base not in (types.ObjectType, types.TypeType):
                 self._bases.append(ClassTree(base))
                 self._bases.append(ClassTree(base))
+    def getAllClasses(self):
+        # returns set of this class and all base classes
+        classes = set()
+        classes.add(self._cls)
+        for base in self._bases:
+            classes.update(base.getAllClasses())
+        return classes
     def _getStr(self, indent=None, clsLeftAtIndent=None):
     def _getStr(self, indent=None, clsLeftAtIndent=None):
         # indent is how far to the right to indent (i.e. how many levels
         # indent is how far to the right to indent (i.e. how many levels
         # deep in the hierarchy from the most-derived)
         # deep in the hierarchy from the most-derived)
@@ -2433,6 +2347,16 @@ class ClassTree:
             s += ' +'
             s += ' +'
         s += self._cls.__name__
         s += self._cls.__name__
         clsLeftAtIndent[indent] -= 1
         clsLeftAtIndent[indent] -= 1
+        """
+        ### show the module to the right of the class name
+        moduleIndent = 48
+        if len(s) >= moduleIndent:
+            moduleIndent = (len(s) % 4) + 4
+        padding = moduleIndent - len(s)
+        s += padding * ' '
+        s += self._cls.__module__
+        ###
+        """
         if len(self._bases):
         if len(self._bases):
             newList = list(clsLeftAtIndent)
             newList = list(clsLeftAtIndent)
             newList.append(len(self._bases))
             newList.append(len(self._bases))
@@ -2445,6 +2369,196 @@ class ClassTree:
     def __repr__(self):
     def __repr__(self):
         return self._getStr()
         return self._getStr()
 
 
+class EnforcedCalldownException(Exception):
+    def __init__(self, what):
+        Exception.__init__(self, what)
+
+class EnforcesCalldowns:
+    """Derive from this class if you want to ensure that specific methods
+    get called.  See calldownEnforced decorator below"""
+
+    # class-level data for enforcement of base class method call-down
+    #
+    # The problem is that we don't have access to the class in the
+    # decorator, so we need to put the decorated methods in a global
+    # dict. We can then insert a stub method on each class instance for
+    # every method that has enforced base-class methods, and the stub can
+    # watch for each base-class method checkpoint to be passed.
+
+    # since the decorator can't know its own id until after it has been
+    # defined, we map from decorator ID to original func ID
+    _decoId2funcId = {}
+    # as calldownEnforced decorators are created, they add themselves to
+    # this dict. At this point we don't know what class they belong to.
+    _funcId2func = {}
+    # this is here so that we can print nice error messages
+    _funcId2class = {}
+    # as class instances are created, we populate this dictionary of
+    # class to func name to list of func ids. The lists of func ids
+    # include base-class funcs.
+    _class2funcName2funcIds = {}
+
+    # this method will be inserted into instances of classes that need
+    # to enforce base-class method calls, as the most-derived implementation
+    # of the method
+    @staticmethod
+    def _enforceCalldowns(oldMethod, name, obj, *args, **kArgs):
+        name2funcIds = EnforcesCalldowns._class2funcName2funcIds[obj.__class__]
+        funcIds = name2funcIds.get(name)
+
+        # prepare for the method call
+        for funcId in funcIds:
+            obj._EClatch(funcId)
+
+        # call the actual method that we're stubbing
+        result = oldMethod(*args, **kArgs)
+
+        # check on the results
+        for funcId in funcIds:
+            obj._ECcheck(funcId)
+
+        return result
+
+    def __init__(self):
+        if not __debug__:
+            return
+
+        # this map tracks how many times each func has been called
+        self._funcId2calls = {}
+        # this map tracks the 'latch' values for each func; if the call count
+        # for a func is greater than the latch, then the func has been called.
+        self._funcId2latch = {}
+
+        if self.__class__ not in EnforcesCalldowns._class2funcName2funcIds:
+            # prepare stubs to enforce method call-downs
+            EnforcesCalldowns._class2funcName2funcIds.setdefault(self.__class__, {})
+            # look through all of our base classes and find matches
+            classes = ClassTree(self).getAllClasses()
+            # collect IDs of all the enforced methods
+            funcId2func = {}
+            for cls in classes:
+                for name, item in cls.__dict__.items():
+                    if id(item) in EnforcesCalldowns._decoId2funcId:
+                        funcId = EnforcesCalldowns._decoId2funcId[id(item)]
+                        funcId2func[funcId] = item
+                        EnforcesCalldowns._funcId2class[funcId] = cls
+            # add these funcs to the list for our class
+            funcName2funcIds = EnforcesCalldowns._class2funcName2funcIds[self.__class__]
+            for funcId, func in funcId2func.items():
+                funcName2funcIds.setdefault(func.__name__, [])
+                funcName2funcIds[func.__name__].append(funcId)
+
+        # now run through all the enforced funcs for this class and insert
+        # stub methods to do the enforcement
+        funcName2funcIds = EnforcesCalldowns._class2funcName2funcIds[self.__class__]
+        self._obscuredMethodNames = set()
+        for name in funcName2funcIds:
+            oldMethod = getattr(self, name)
+            self._obscuredMethodNames.add(name)
+            setattr(self, name, new.instancemethod(
+                Functor(EnforcesCalldowns._enforceCalldowns, oldMethod, name),
+                self, self.__class__))
+            
+    def destroy(self):
+        # this must be called on destruction to prevent memory leaks
+        #import pdb;pdb.set_trace()
+        for name in self._obscuredMethodNames:
+            delattr(self, name)
+        del self._obscuredMethodNames
+        # this opens up more cans of worms. Let's keep it closed for the moment
+        #del self._funcId2calls
+        #del self._funcId2latch
+
+    def skipCalldown(self, method):
+        if not __debug__:
+            return
+        # Call this function if you really don't want to call down to an
+        # enforced base-class method. This should hardly ever be used.
+        funcName2funcIds = EnforcesCalldowns._class2funcName2funcIds[self.__class__]
+        funcIds = funcName2funcIds[method.__name__]
+        for funcId in funcIds:
+            self._ECvisit(funcId)
+
+    def _EClatch(self, funcId):
+        self._funcId2calls.setdefault(funcId, 0)
+        self._funcId2latch[funcId] = self._funcId2calls[funcId]
+    def _ECvisit(self, funcId):
+        self._funcId2calls.setdefault(funcId, 0)
+        self._funcId2calls[funcId] += 1
+    def _ECcheck(self, funcId):
+        if self._funcId2latch[funcId] == self._funcId2calls[funcId]:
+            func = EnforcesCalldowns._funcId2func[funcId]
+            raise EnforcedCalldownException(
+                '%s.%s did not call down to %s.%s\n%s' % (
+                self.__class__.__module__, self.__class__.__name__,
+                EnforcesCalldowns._funcId2class[funcId].__name__,
+                func.__name__,
+                ClassTree(self)))
+
+def calldownEnforced(f):
+    """
+    Use this decorator to ensure that derived classes that override this method
+    call down to the base class method.
+    """
+    if not __debug__:
+        return f
+    def calldownEnforcedImpl(obj, *args, **kArgs):
+        # track the fact that this func has been called
+        obj._ECvisit(id(f))
+        f(obj, *args, **kArgs)
+    calldownEnforcedImpl.__doc__ = f.__doc__
+    calldownEnforcedImpl.__name__ = f.__name__
+    calldownEnforcedImpl.__module__ = f.__module__
+    EnforcesCalldowns._decoId2funcId[id(calldownEnforcedImpl)] = id(f)
+    EnforcesCalldowns._funcId2func[id(f)] = calldownEnforcedImpl
+    return calldownEnforcedImpl
+
+if __debug__:
+    class CalldownEnforceTest(EnforcesCalldowns):
+        @calldownEnforced
+        def testFunc(self):
+            pass
+    class CalldownEnforceTestSubclass(CalldownEnforceTest):
+        def testFunc(self):
+            CalldownEnforceTest.testFunc(self)
+    class CalldownEnforceTestSubclassFail(CalldownEnforceTest):
+        def testFunc(self):
+            pass
+    class CalldownEnforceTestSubclassSkip(CalldownEnforceTest):
+        def testFunc(self):
+            self.skipCalldown(CalldownEnforceTest.testFunc)
+    cets = CalldownEnforceTestSubclass()
+    cetsf = CalldownEnforceTestSubclassFail()
+    cetss = CalldownEnforceTestSubclassSkip()
+    raised = False
+    try:
+        cets.testFunc()
+    except EnforcedCalldownException, e:
+        raised = True
+    if raised:
+        raise "calldownEnforced raised when it shouldn't"
+    raised = False
+    try:
+        cetsf.testFunc()
+    except EnforcedCalldownException, e:
+        raised = True
+    if not raised:
+        raise 'calldownEnforced failed to raise'
+    raised = False
+    try:
+        cetss.testFunc()
+    except EnforcedCalldownException, e:
+        raised = True
+    if raised:
+        raise "calldownEnforced.skipCalldown raised when it shouldn't"
+    del cetss
+    del cetsf
+    del cets
+    del CalldownEnforceTestSubclassSkip
+    del CalldownEnforceTestSubclassFail
+    del CalldownEnforceTestSubclass
+    del CalldownEnforceTest
+
 import __builtin__
 import __builtin__
 __builtin__.Functor = Functor
 __builtin__.Functor = Functor
 __builtin__.Stack = Stack
 __builtin__.Stack = Stack