Bladeren bron

stdpy: Fix direct.stdpy.pickle module for Python 3

rdb 5 jaren geleden
bovenliggende
commit
2c209e0f02
2 gewijzigde bestanden met toevoegingen van 57 en 10 verwijderingen
  1. 46 10
      direct/src/stdpy/pickle.py
  2. 11 0
      tests/stdpy/test_pickle.py

+ 46 - 10
direct/src/stdpy/pickle.py

@@ -33,19 +33,30 @@ else:
 # with the local pickle.py.
 pickle = __import__('pickle')
 
-class Pickler(pickle.Pickler):
+if sys.version_info >= (3, 0):
+    BasePickler = pickle._Pickler
+    BaseUnpickler = pickle._Unpickler
+else:
+    BasePickler = pickle.Pickler
+    BaseUnpickler = pickle.Unpickler
+
+
+class _Pickler(BasePickler):
 
     def __init__(self, *args, **kw):
         self.bamWriter = BamWriter()
-        pickle.Pickler.__init__(self, *args, **kw)
+        BasePickler.__init__(self, *args, **kw)
 
     # We have to duplicate most of the save() method, so we can add
     # support for __reduce_persist__().
 
-    def save(self, obj):
+    def save(self, obj, save_persistent_id=True):
+        if self.proto >= 4:
+            self.framer.commit_frame()
+
         # Check for persistent id (defined by a subclass)
         pid = self.persistent_id(obj)
-        if pid:
+        if pid is not None and save_persistent_id:
             self.save_pers(pid)
             return
 
@@ -112,11 +123,12 @@ class Pickler(pickle.Pickler):
         # Save the reduce() output and finally memoize the object
         self.save_reduce(obj=obj, *rv)
 
-class Unpickler(pickle.Unpickler):
+
+class Unpickler(BaseUnpickler):
 
     def __init__(self, *args, **kw):
         self.bamReader = BamReader()
-        pickle.Unpickler.__init__(self, *args, **kw)
+        BaseUnpickler.__init__(self, *args, **kw)
 
     # Duplicate the load_reduce() function, to provide a special case
     # for the reduction function.
@@ -126,9 +138,10 @@ class Unpickler(pickle.Unpickler):
         args = stack.pop()
         func = stack[-1]
 
-        # If the function name ends with "Persist", then assume the
+        # If the function name ends with "_persist", then assume the
         # function wants the Unpickler as the first parameter.
-        if func.__name__.endswith('Persist'):
+        func_name = func.__name__
+        if func_name.endswith('_persist') or func_name.endswith('Persist'):
             value = func(self, *args)
         else:
             # Otherwise, use the existing pickle convention.
@@ -136,9 +149,32 @@ class Unpickler(pickle.Unpickler):
 
         stack[-1] = value
 
-    #FIXME: how to replace in Python 3?
+    if sys.version_info >= (3, 0):
+        BaseUnpickler.dispatch[pickle.REDUCE[0]] = load_reduce
+    else:
+        BaseUnpickler.dispatch[pickle.REDUCE] = load_reduce
+
+
+if sys.version_info >= (3, 8):
+    # In Python 3.8 and up, we can use the C implementation of Pickler, which
+    # supports a reducer_override method.
+    class Pickler(pickle.Pickler):
+        def __init__(self, *args, **kw):
+            self.bamWriter = BamWriter()
+            pickle.Pickler.__init__(self, *args, **kw)
+
+        def reducer_override(self, obj):
+            reduce = getattr(obj, "__reduce_persist__", None)
+            if reduce:
+                return reduce(self)
+
+            return NotImplemented
+else:
+    # Otherwise, we have to use our custom version that overrides save().
+    Pickler = _Pickler
+
     if sys.version_info < (3, 0):
-        pickle.Unpickler.dispatch[pickle.REDUCE] = load_reduce
+        del _Pickler
 
 
 # Shorthands

+ 11 - 0
tests/stdpy/test_pickle.py

@@ -0,0 +1,11 @@
+from direct.stdpy.pickle import dumps, loads
+
+
+def test_reduce_persist():
+    from panda3d.core import NodePath
+
+    parent = NodePath("parent")
+    child = parent.attach_new_node("child")
+
+    parent2, child2 = loads(dumps([parent, child]))
+    assert tuple(parent2.children) == (child2,)