Browse Source

express: distinguish between null vs empty in CPTA pickle as well

rdb 5 years ago
parent
commit
a84f1b5595
2 changed files with 33 additions and 1 deletions
  1. 8 1
      panda/src/express/pointerToArray_ext.I
  2. 25 0
      tests/express/test_pointertoarray.py

+ 8 - 1
panda/src/express/pointerToArray_ext.I

@@ -329,7 +329,14 @@ get_subdata(size_t n, size_t count) const {
 template<class Element>
 template<class Element>
 INLINE PyObject *Extension<ConstPointerToArray<Element> >::
 INLINE PyObject *Extension<ConstPointerToArray<Element> >::
 __reduce__(PyObject *self) const {
 __reduce__(PyObject *self) const {
-  return Py_BuildValue("O(N)", Py_TYPE(self), get_data());
+  // This preserves the distinction between a null vs. an empty PTA, though I'm
+  // not sure that this distinction matters to anyone.
+  if (!this->_this->is_null() && this->_this->empty()) {
+    return Py_BuildValue("O([])", Py_TYPE(self));
+  }
+  else {
+    return Py_BuildValue("O(N)", Py_TYPE(self), get_data());
+  }
 }
 }
 
 
 /**
 /**

+ 25 - 0
tests/express/test_pointertoarray.py

@@ -44,3 +44,28 @@ def test_pta_float_pickle():
         data_pta2 = loads(dumps(data_pta, proto))
         data_pta2 = loads(dumps(data_pta, proto))
         assert tuple(data_pta2) == (1.0, 2.0, 3.0)
         assert tuple(data_pta2) == (1.0, 2.0, 3.0)
         assert data_pta2.get_data() == data_pta.get_data()
         assert data_pta2.get_data() == data_pta.get_data()
+
+
+def test_cpta_float_pickle():
+    from panda3d.core import PTA_float, CPTA_float
+    from direct.stdpy.pickle import dumps, loads, HIGHEST_PROTOCOL
+
+    null_pta = CPTA_float(PTA_float())
+
+    empty_pta = CPTA_float([])
+
+    data_pta = CPTA_float([1.0, 2.0, 3.0])
+    data = data_pta.get_data()
+
+    for proto in range(1, HIGHEST_PROTOCOL + 1):
+        null_pta2 = loads(dumps(null_pta, proto))
+        assert null_pta2.is_null()
+        assert len(null_pta2) == 0
+
+        empty_pta2 = loads(dumps(empty_pta, proto))
+        assert not empty_pta2.is_null()
+        assert len(empty_pta2) == 0
+
+        data_pta2 = loads(dumps(data_pta, proto))
+        assert tuple(data_pta2) == (1.0, 2.0, 3.0)
+        assert data_pta2.get_data() == data_pta.get_data()