""" This module extends standard Python's pickle module so that it is
capable of writing more efficient pickle files that contain Panda
objects with shared pointers.  In particular, a single Python
structure that contains many NodePaths into the same scene graph will
write the NodePaths correctly when used with this pickle module, so
that when it is unpickled later, the NodePaths will still reference
into the same scene graph together.

If you use the standard pickle module instead, the NodePaths will each
duplicate its own copy of its scene graph.

This is necessary because the standard pickle module doesn't provide a
mechanism for sharing context between different objects written to the
same pickle stream, so each NodePath has to write itself without
knowing about the other NodePaths that will also be writing to the
same stream.  This replacement module solves this problem by defining
a ``__reduce_persist__()`` replacement method for ``__reduce__()``,
which accepts a pointer to the Pickler object itself, allowing for
shared context between all objects written by that Pickler.

Unfortunately, cPickle cannot be supported, because it does not
support extensions of this nature. """
  20. import sys
  21. from panda3d.core import BamWriter, BamReader
  22. from copyreg import dispatch_table
  23. # A funny replacement for "import pickle" so we don't get confused
  24. # with the local pickle.py.
  25. pickle = __import__('pickle')
  26. PicklingError = pickle.PicklingError
  27. BasePickler = pickle._Pickler
  28. BaseUnpickler = pickle._Unpickler
  29. class _Pickler(BasePickler):
  30. def __init__(self, *args, **kw):
  31. self.bamWriter = BamWriter()
  32. BasePickler.__init__(self, *args, **kw)
  33. # We have to duplicate most of the save() method, so we can add
  34. # support for __reduce_persist__().
  35. def save(self, obj, save_persistent_id=True):
  36. if self.proto >= 4:
  37. self.framer.commit_frame()
  38. # Check for persistent id (defined by a subclass)
  39. pid = self.persistent_id(obj)
  40. if pid is not None and save_persistent_id:
  41. self.save_pers(pid)
  42. return
  43. # Check the memo
  44. x = self.memo.get(id(obj))
  45. if x:
  46. self.write(self.get(x[0]))
  47. return
  48. # Check the type dispatch table
  49. t = type(obj)
  50. f = self.dispatch.get(t)
  51. if f:
  52. f(self, obj) # Call unbound method with explicit self
  53. return
  54. # Check for a class with a custom metaclass; treat as regular class
  55. try:
  56. issc = issubclass(t, type)
  57. except TypeError: # t is not a class (old Boost; see SF #502085)
  58. issc = 0
  59. if issc:
  60. self.save_global(obj)
  61. return
  62. # Check copy_reg.dispatch_table
  63. reduce = dispatch_table.get(t)
  64. if reduce:
  65. rv = reduce(obj)
  66. else:
  67. # New code: check for a __reduce_persist__ method, then
  68. # fall back to standard methods.
  69. reduce = getattr(obj, "__reduce_persist__", None)
  70. if reduce:
  71. rv = reduce(self)
  72. else:
  73. # Check for a __reduce_ex__ method, fall back to __reduce__
  74. reduce = getattr(obj, "__reduce_ex__", None)
  75. if reduce:
  76. rv = reduce(self.proto)
  77. else:
  78. reduce = getattr(obj, "__reduce__", None)
  79. if reduce:
  80. rv = reduce()
  81. else:
  82. raise PicklingError("Can't pickle %r object: %r" %
  83. (t.__name__, obj))
  84. # Check for string returned by reduce(), meaning "save as global"
  85. if type(rv) is str:
  86. self.save_global(obj, rv)
  87. return
  88. # Assert that reduce() returned a tuple
  89. if type(rv) is not tuple:
  90. raise PicklingError("%s must return string or tuple" % reduce)
  91. # Assert that it returned an appropriately sized tuple
  92. l = len(rv)
  93. if not (2 <= l <= 5):
  94. raise PicklingError("Tuple returned by %s must have "
  95. "two to five elements" % reduce)
  96. # Save the reduce() output and finally memoize the object
  97. self.save_reduce(obj=obj, *rv)
  98. class Unpickler(BaseUnpickler):
  99. def __init__(self, *args, **kw):
  100. self.bamReader = BamReader()
  101. BaseUnpickler.__init__(self, *args, **kw)
  102. # Duplicate the load_reduce() function, to provide a special case
  103. # for the reduction function.
  104. def load_reduce(self):
  105. stack = self.stack
  106. args = stack.pop()
  107. func = stack[-1]
  108. # If the function name ends with "_persist", then assume the
  109. # function wants the Unpickler as the first parameter.
  110. func_name = func.__name__
  111. if func_name.endswith('_persist') or func_name.endswith('Persist'):
  112. value = func(self, *args)
  113. else:
  114. # Otherwise, use the existing pickle convention.
  115. value = func(*args)
  116. stack[-1] = value
  117. BaseUnpickler.dispatch[pickle.REDUCE[0]] = load_reduce
  118. if sys.version_info >= (3, 8):
  119. # In Python 3.8 and up, we can use the C implementation of Pickler, which
  120. # supports a reducer_override method.
  121. class Pickler(pickle.Pickler):
  122. def __init__(self, *args, **kw):
  123. self.bamWriter = BamWriter()
  124. pickle.Pickler.__init__(self, *args, **kw)
  125. def reducer_override(self, obj):
  126. reduce = getattr(obj, "__reduce_persist__", None)
  127. if reduce:
  128. return reduce(self)
  129. return NotImplemented
  130. else:
  131. # Otherwise, we have to use our custom version that overrides save().
  132. Pickler = _Pickler
  133. # Shorthands
  134. from io import BytesIO
  135. def dump(obj, file, protocol=None):
  136. Pickler(file, protocol).dump(obj)
  137. def dumps(obj, protocol=None):
  138. file = BytesIO()
  139. Pickler(file, protocol).dump(obj)
  140. return file.getvalue()
  141. def load(file):
  142. return Unpickler(file).load()
  143. def loads(str):
  144. file = BytesIO(str)
  145. return Unpickler(file).load()