Browse Source

POD: fixed compilation of DataSet wrt inheritance, removed unnecessary getter default value handling, better documentation, more testing

Darren Ranalli 18 years ago
parent
commit
a5502e3b43
1 changed files with 54 additions and 13 deletions
  1. 54 13
      direct/src/showbase/PythonUtil.py

+ 54 - 13
direct/src/showbase/PythonUtil.py

@@ -1467,7 +1467,7 @@ class POD:
         # appear here as 'name': defaultValue,
         # appear here as 'name': defaultValue,
         #
         #
         # WARNING: default values of mutable types that do not copy by
         # WARNING: default values of mutable types that do not copy by
-        # value (dicts, lists etc.) will be shared by all class instances
+        # value (dicts, lists etc.) will be shared by all class instances.
         # if default value is callable, it will be called to get actual
         # if default value is callable, it will be called to get actual
         # default value
         # default value
         #
         #
@@ -1481,11 +1481,16 @@ class POD:
     def __init__(self, **kwArgs):
     def __init__(self, **kwArgs):
         self.__class__._compileDefaultDataSet()
         self.__class__._compileDefaultDataSet()
         if __debug__:
         if __debug__:
+            # make sure all of the keyword arguments passed in
+            # are present in our data set
             for arg in kwArgs.keys():
             for arg in kwArgs.keys():
                 assert arg in self.getDataNames(), (
                 assert arg in self.getDataNames(), (
                     "unknown argument for %s: '%s'" % (
                     "unknown argument for %s: '%s'" % (
                     self.__class__, arg))
                     self.__class__, arg))
+        # assign each of our data items directly to self
         for name in self.getDataNames():
         for name in self.getDataNames():
+            # if a value has been passed in for a data item, use
+            # that value, otherwise use the default value
             if name in kwArgs:
             if name in kwArgs:
                 getSetter(self, name)(kwArgs[name])
                 getSetter(self, name)(kwArgs[name])
             else:
             else:
@@ -1531,6 +1536,13 @@ class POD:
     def getDefaultValue(cls, name):
     def getDefaultValue(cls, name):
         cls._compileDefaultDataSet()
         cls._compileDefaultDataSet()
         dv = cls._DataSet[name]
         dv = cls._DataSet[name]
+        # this allows us to create a new mutable object every time we ask
+        # for its default value, i.e. if the default value is dict, this
+        # method will return a new empty dictionary object every time. This
+        # will cause problems if the intent is to store a callable object
+        # as the default value itself; we need a way to specify that the
+        # callable *is* the default value and not a default-value creation
+        # function
         if callable(dv):
         if callable(dv):
             dv = dv()
             dv = dv()
         return dv
         return dv
@@ -1549,24 +1561,26 @@ class POD:
                     cls.__dict__[setterName] = defaultSetter
                     cls.__dict__[setterName] = defaultSetter
                 getterName = getSetterName(name, 'get')
                 getterName = getSetterName(name, 'get')
                 if not hasattr(cls, getterName):
                 if not hasattr(cls, getterName):
-                    def defaultGetter(self, name=name,
-                                      default=cls.DataSet[name]):
-                        return getattr(self, name, default)
+                    def defaultGetter(self, name=name):
+                        return getattr(self, name)
                     cls.__dict__[getterName] = defaultGetter
                     cls.__dict__[getterName] = defaultGetter
         # this dict will hold all of the aggregated default data values for
         # this dict will hold all of the aggregated default data values for
         # this particular class, including values from its base classes
         # this particular class, including values from its base classes
         cls._DataSet = {}
         cls._DataSet = {}
         bases = list(cls.__bases__)
         bases = list(cls.__bases__)
-        # bring less-derived classes to the front
-        mostDerivedLast(bases)
-        for c in (bases + [cls]):
+        # process in reverse of inheritance order, so that base classes listed first
+        # will take precedence over later base classes
+        bases.reverse()
+        for curBase in bases:
             # skip multiple-inheritance base classes that do not derive from POD
             # skip multiple-inheritance base classes that do not derive from POD
-            if issubclass(c, POD):
+            if issubclass(curBase, POD):
                 # make sure this base has its dict of data defaults
                 # make sure this base has its dict of data defaults
-                c._compileDefaultDataSet()
-                if c.__dict__.has_key('DataSet'):
-                    # apply this class' default data values to our dict
-                    cls._DataSet.update(c.DataSet)
+                curBase._compileDefaultDataSet()
+                # grab all inherited data default values
+                cls._DataSet.update(curBase._DataSet)
+        # pull in our own class' default values if any are specified
+        if 'DataSet' in cls.__dict__:
+            cls._DataSet.update(cls.DataSet)
 
 
     def __repr__(self):
     def __repr__(self):
         argStr = ''
         argStr = ''
@@ -1581,8 +1595,35 @@ if __debug__:
             }
             }
     p1 = PODtest()
     p1 = PODtest()
     p2 = PODtest()
     p2 = PODtest()
-    p1.foo[1] = 2
+    assert hasattr(p1, 'foo')
+    # make sure the getter is working
+    assert p1.getFoo() is p1.foo
+    p1.getFoo()[1] = 2
+    assert p1.foo[1] == 2
+    # make sure that each instance gets its own copy of a mutable
+    # data item
+    assert p1.foo is not p2.foo
+    assert len(p1.foo) == 1
     assert len(p2.foo) == 0
     assert len(p2.foo) == 0
+    # make sure the setter is working
+    p2.setFoo({10:20})
+    assert p2.foo[10] == 20
+    # make sure modifications to mutable data items don't affect other
+    # instances
+    assert p1.foo[1] == 2
+
+    class DerivedPOD(PODtest):
+        DataSet = {
+            'bar': list,
+            }
+    d1 = DerivedPOD()
+    # make sure that derived instances get their own copy of mutable
+    # data items
+    assert hasattr(d1, 'foo')
+    assert len(d1.foo) == 0
+    # make sure derived instances get their own items
+    assert hasattr(d1, 'bar')
+    assert len(d1.bar) == 0
 
 
 def bound(value, bound1, bound2):
 def bound(value, bound1, bound2):
     """
     """