Browse Source

Implement support for custom TypedWritable subclasses in Python

Fixes #1495
rdb 2 years ago
parent
commit
115a716a7d

+ 1 - 0
dtool/src/dtoolbase/dtoolbase.h

@@ -130,6 +130,7 @@
 // flag to link in Python, we'll just excerpt the forward declaration of
 // PyObject.
 typedef struct _object PyObject;
+typedef struct _typeobject PyTypeObject;
 
 #ifndef HAVE_EIGEN
 // If we don't have the Eigen library, don't define LINMATH_ALIGN.

+ 17 - 1
dtool/src/dtoolbase/typeHandle.cxx

@@ -155,7 +155,7 @@ deallocate_array(void *ptr) {
 /**
  * Returns the internal void pointer that is stored for interrogate's benefit.
  */
-PyObject *TypeHandle::
+PyTypeObject *TypeHandle::
 get_python_type() const {
   TypeRegistryNode *rnode = TypeRegistry::ptr()->look_up(*this, nullptr);
   if (rnode != nullptr) {
@@ -164,6 +164,22 @@ get_python_type() const {
     return nullptr;
   }
 }
+
+/**
+ * Returns a Python wrapper object corresponding to the given C++ pointer.
+ */
+PyObject *TypeHandle::
+wrap_python(void *ptr, PyTypeObject *cast_from) const {
+  if (ptr == nullptr) {
+    return nullptr;
+  }
+  TypeRegistryNode *rnode = TypeRegistry::ptr()->look_up(*this, nullptr);
+  if (rnode != nullptr) {
+    return rnode->wrap_python(ptr, cast_from);
+  } else {
+    return nullptr;
+  }
+}
 #endif
 
 std::ostream &

+ 2 - 1
dtool/src/dtoolbase/typeHandle.h

@@ -140,7 +140,8 @@ PUBLISHED:
 
 public:
 #ifdef HAVE_PYTHON
-  PyObject *get_python_type() const;
+  PyTypeObject *get_python_type() const;
+  PyObject *wrap_python(void *ptr, PyTypeObject *cast_from = nullptr) const;
 #endif // HAVE_PYTHON
 
   void *allocate_array(size_t size) RETURNS_ALIGNED(MEMORY_HOOK_ALIGNMENT);

+ 20 - 1
dtool/src/dtoolbase/typeHandle_ext.cxx

@@ -46,7 +46,26 @@ __reduce__() const {
 
   // If we have a Python binding registered for it, that's the preferred method,
   // since it ensures that the appropriate module gets loaded by pickle.
-  PyObject *py_type = _this->get_python_type();
+  PyTypeObject *py_type = _this->get_python_type();
+  if (py_type != nullptr && py_type->tp_dict != nullptr) {
+    // Look for a get_class_type method, if it returns this handle.
+    PyObject *func = PyDict_GetItemString(py_type->tp_dict, "get_class_type");
+    if (func != nullptr && PyCallable_Check(func)) {
+      PyObject *result = PyObject_CallNoArgs(func);
+      TypeHandle *result_handle = nullptr;
+      if (result == nullptr) {
+        // Never mind.
+        PyErr_Clear();
+      }
+      else if (DtoolInstance_GetPointer(result, result_handle, Dtool_TypeHandle) &&
+               *result_handle == *_this) {
+        // It returned the correct result, so we can use this.
+        return Py_BuildValue("O()", func);
+      }
+    }
+  }
+
+  // Fall back to TypeHandle::make(), if would produce the correct result.
   if (py_type != nullptr && *_this == ((Dtool_PyTypedObject *)py_type)->_type) {
     PyObject *func = PyObject_GetAttrString((PyObject *)&Dtool_TypeHandle, "make");
     return Py_BuildValue("N(O)", func, py_type);

+ 3 - 2
dtool/src/dtoolbase/typeRegistry.cxx

@@ -213,12 +213,13 @@ record_alternate_name(TypeHandle type, const string &name) {
  * of interrogate, which expects this to contain a Dtool_PyTypedObject.
  */
 void TypeRegistry::
-record_python_type(TypeHandle type, PyObject *python_type) {
+record_python_type(TypeHandle type, PyTypeObject *cls, PythonWrapFunc *wrap_func) {
   _lock.lock();
 
   TypeRegistryNode *rnode = look_up(type, nullptr);
   if (rnode != nullptr) {
-    rnode->_python_type = python_type;
+    rnode->_python_type = cls;
+    rnode->_python_wrap_func = wrap_func;
   }
 
   _lock.unlock();

+ 6 - 1
dtool/src/dtoolbase/typeRegistry.h

@@ -40,13 +40,18 @@ public:
   // convenience function, defined in register_type.h.
   bool register_type(TypeHandle &type_handle, const std::string &name);
 
+#ifdef HAVE_PYTHON
+  typedef PyObject *PythonWrapFunc(void *ptr, PyTypeObject *cast_from);
+#endif
+
 PUBLISHED:
   TypeHandle register_dynamic_type(const std::string &name);
 
   void record_derivation(TypeHandle child, TypeHandle parent);
   void record_alternate_name(TypeHandle type, const std::string &name);
 #ifdef HAVE_PYTHON
-  void record_python_type(TypeHandle type, PyObject *python_type);
+  void record_python_type(TypeHandle type, PyTypeObject *cls,
+                          PythonWrapFunc *wrap_func);
 #endif
 
   TypeHandle find_type(const std::string &name) const;

+ 18 - 1
dtool/src/dtoolbase/typeRegistryNode.I

@@ -14,7 +14,7 @@
 /**
  * Returns the Python type object associated with this node.
  */
-INLINE PyObject *TypeRegistryNode::
+INLINE PyTypeObject *TypeRegistryNode::
 get_python_type() const {
   if (_python_type != nullptr || _parent_classes.empty()) {
     return _python_type;
@@ -24,6 +24,23 @@ get_python_type() const {
   }
 }
 
+/**
+ * Returns a Python wrapper object corresponding to the given C++ pointer.
+ */
+INLINE PyObject *TypeRegistryNode::
+wrap_python(void *ptr, PyTypeObject *cast_from) const {
+  if (_python_wrap_func != nullptr) {
+    return _python_wrap_func(ptr, cast_from);
+  }
+  else if (_parent_classes.empty()) {
+    return nullptr;
+  }
+  else {
+    // Recurse through parent classes.
+    return r_wrap_python(ptr, cast_from);
+  }
+}
+
 /**
  *
  */

+ 27 - 7
dtool/src/dtoolbase/typeRegistryNode.cxx

@@ -311,16 +311,14 @@ r_build_subtrees(TypeRegistryNode *top, int bit_count,
  * Recurses through the parent nodes to find the best Python type object to
  * represent objects of this type.
  */
-PyObject *TypeRegistryNode::
+PyTypeObject *TypeRegistryNode::
 r_get_python_type() const {
-  Classes::const_iterator ni;
-  for (ni = _parent_classes.begin(); ni != _parent_classes.end(); ++ni) {
-    const TypeRegistryNode *parent = *ni;
+  for (const TypeRegistryNode *parent : _parent_classes) {
     if (parent->_python_type != nullptr) {
       return parent->_python_type;
-
-    } else if (!parent->_parent_classes.empty()) {
-      PyObject *py_type = parent->r_get_python_type();
+    }
+    else if (!parent->_parent_classes.empty()) {
+      PyTypeObject *py_type = parent->r_get_python_type();
       if (py_type != nullptr) {
         return py_type;
       }
@@ -330,6 +328,28 @@ r_get_python_type() const {
   return nullptr;
 }
 
+/**
+ * Creates a Python wrapper object to represent the given C++ pointer, which
+ * must be exactly of the correct type, unless cast_from is set to one of its
+ * base classes, in which case it will be cast appropriately.
+ */
+PyObject *TypeRegistryNode::
+r_wrap_python(void *ptr, PyTypeObject *cast_from) const {
+  for (const TypeRegistryNode *parent : _parent_classes) {
+    if (parent->_python_wrap_func != nullptr) {
+      return parent->_python_wrap_func(ptr, cast_from);
+    }
+    else if (!parent->_parent_classes.empty()) {
+      PyObject *wrapper = parent->r_wrap_python(ptr, cast_from);
+      if (wrapper != nullptr) {
+        return wrapper;
+      }
+    }
+  }
+
+  return nullptr;
+}
+
 /**
  * A recursive function to double-check the result of is_derived_from().  This
  * is the slow, examine-the-whole-graph approach, as opposed to the clever and

+ 8 - 3
dtool/src/dtoolbase/typeRegistryNode.h

@@ -30,6 +30,8 @@
  */
 class EXPCL_DTOOL_DTOOLBASE TypeRegistryNode {
 public:
+  typedef PyObject *PythonWrapFunc(void *ptr, PyTypeObject *cast_from);
+
   TypeRegistryNode(TypeHandle handle, const std::string &name, TypeHandle &ref);
 
   static bool is_derived_from(const TypeRegistryNode *child,
@@ -38,7 +40,8 @@ public:
   static TypeHandle get_parent_towards(const TypeRegistryNode *child,
                                        const TypeRegistryNode *base);
 
-  INLINE PyObject *get_python_type() const;
+  INLINE PyTypeObject *get_python_type() const;
+  INLINE PyObject *wrap_python(void *ptr, PyTypeObject *cast_from) const;
 
   void clear_subtree();
   void define_subtree();
@@ -49,7 +52,8 @@ public:
   typedef std::vector<TypeRegistryNode *> Classes;
   Classes _parent_classes;
   Classes _child_classes;
-  PyObject *_python_type = nullptr;
+  PyTypeObject *_python_type = nullptr;
+  PythonWrapFunc *_python_wrap_func = nullptr;
 
   patomic<size_t> _memory_usage[TypeHandle::MC_limit];
 
@@ -81,7 +85,8 @@ private:
   void r_build_subtrees(TypeRegistryNode *top,
                         int bit_count, SubtreeMaskType bits);
 
-  PyObject *r_get_python_type() const;
+  PyTypeObject *r_get_python_type() const;
+  PyObject *r_wrap_python(void *ptr, PyTypeObject *cast_from) const;
 
   static bool check_derived_from(const TypeRegistryNode *child,
                                  const TypeRegistryNode *base);

+ 43 - 15
dtool/src/interrogate/interfaceMakerPythonNative.cxx

@@ -115,6 +115,7 @@ RenameSet methodRenameDictionary[] = {
   { "__deepcopy__"  , "__deepcopy__",           0 },
   { "__getstate__"  , "__getstate__",           0 },
   { "__setstate__"  , "__setstate__",           0 },
+  { "__new__"       , "__new__",                0 },
   { "print"         , "Cprint",                 0 },
   { "CInterval.set_t", "_priv__cSetT",          0 },
   { nullptr, nullptr, -1 }
@@ -664,6 +665,12 @@ get_slotted_function_def(Object *obj, Function *func, FunctionRemap *remap,
     return true;
   }
 
+  if (method_name == "__new__") {
+    def._answer_location = "tp_new";
+    def._wrapper_type = WT_new;
+    return true;
+  }
+
   if (remap->_type == FunctionRemap::T_typecast_method) {
     // A typecast operator.  Check for a supported low-level typecast type.
     if (TypeManager::is_bool(remap->_return_type->get_orig_type())) {
@@ -1106,10 +1113,14 @@ write_class_details(ostream &out, Object *obj) {
   out << " */\n";
 
   // First write out all the wrapper functions for the methods.
+  bool have_new = false;
   for (Function *func : obj->_methods) {
     if (func) {
       // Write the definition of the generic wrapper function for this
       // function.
+      if (func->_ifunc.get_name() == "__new__") {
+        have_new = true;
+      }
       write_function_for_top(out, obj, func);
     }
   }
@@ -1128,10 +1139,16 @@ write_class_details(ostream &out, Object *obj) {
   if (obj->_constructors.size() == 0) {
     // We still need to write a dummy constructor to prevent inheriting the
     // constructor from a base class.
-    out << fname << " {\n"
-      "  Dtool_Raise_TypeError(\"cannot init abstract class\");\n"
-      "  return -1;\n"
-      "}\n\n";
+    if (have_new) {
+      out << fname << " {\n"
+        "  return 0;\n"
+        "}\n\n";
+    } else {
+      out << fname << " {\n"
+        "  Dtool_Raise_TypeError(\"cannot init abstract class\");\n"
+        "  return -1;\n"
+        "}\n\n";
+    }
   }
 
   CPPType *cpptype = TypeManager::resolve_type(obj->_itype._cpptype);
@@ -1221,18 +1238,17 @@ write_class_details(ostream &out, Object *obj) {
     out << "  return nullptr;\n";
     out << "}\n\n";
 
-    out << "static Dtool_PyInstDef *Dtool_Wrap_" << ClassName << "(void *from_this, Dtool_PyTypedObject *from_type) {\n";
-    out << "  if (from_this == nullptr || from_type == nullptr) {\n";
-    out << "    return nullptr;\n";
-    out << "  }\n";
+    //NB. This may be called with nullptr in either argument and should produce
+    // a valid wrapper object even with a null pointer.
+    out << "static PyObject *Dtool_Wrap_" << ClassName << "(void *from_this, PyTypeObject *from_type) {\n";
     out << "  " << cClassName << " *to_this;\n";
-    out << "  if (from_type == &Dtool_" << ClassName << ") {\n";
+    out << "  if (from_type == nullptr || from_type == &Dtool_" << ClassName << "._PyType) {\n";
     out << "    to_this = (" << cClassName << "*)from_this;\n";
     out << "  }\n";
     for (di = details.begin(); di != details.end(); di++) {
       if (di->second._can_downcast && di->second._is_legal_py_class) {
-        out << "  else if (from_type == Dtool_Ptr_" << make_safe_name(di->second._to_class_name) << ") {\n";
-        out << "    " << di->second._to_class_name << "* other_this = (" << di->second._to_class_name << "*)from_this;\n" ;
+        out << "  else if (from_type == (PyTypeObject *)Dtool_Ptr_" << make_safe_name(di->second._to_class_name) << ") {\n";
+        out << "    " << di->second._to_class_name << " *other_this = (" << di->second._to_class_name << " *)from_this;\n" ;
         out << "    to_this = (" << cClassName << "*)other_this;\n";
         out << "  }\n";
       }
@@ -1247,7 +1263,7 @@ write_class_details(ostream &out, Object *obj) {
     out << "  self->_ptr_to_object = to_this;\n";
     out << "  self->_memory_rules = false;\n";
     out << "  self->_is_const = false;\n";
-    out << "  return self;\n";
+    out << "  return (PyObject *)self;\n";
     out << "}\n\n";
   }
 }
@@ -1405,8 +1421,9 @@ write_module_support(ostream &out, ostream *out_h, InterrogateModuleDef *def) {
           out << "    TypeHandle handle = " << type->get_local_name(&parser)
               << "::get_class_type();\n";
           out << "    Dtool_" << safe_name << "._type = handle;\n";
-          out << "    registry->record_python_type(handle, "
-                 "(PyObject *)&Dtool_" << safe_name << ");\n";
+          out << "    registry->record_python_type(handle,"
+                 " &Dtool_" << safe_name << "._PyType,"
+                 " Dtool_Wrap_" << safe_name << ");\n";
           out << "  }\n";
         } else {
           if (IsPandaTypedObject(type->as_struct_type())) {
@@ -2624,6 +2641,17 @@ write_module_class(ostream &out, Object *obj) {
         }
         break;
 
+      case WT_new:
+        {
+          string fname = "static PyObject *" + def._wrapper_name + "(PyTypeObject *cls, PyObject *args, PyObject *kwds)\n";
+
+          std::vector<FunctionRemap *> remaps;
+          remaps.insert(remaps.end(), def._remaps.begin(), def._remaps.end());
+          string expected_params;
+          write_function_for_name(out, obj, remaps, fname, expected_params, true, AT_keyword_args, RF_pyobject | RF_err_null);
+        }
+        break;
+
       case WT_none:
         // Nothing special about the wrapper function: just write it normally.
         string fname = "static PyObject *" + def._wrapper_name + "(PyObject *self, PyObject *args, PyObject *kwds)\n";
@@ -3229,7 +3257,7 @@ write_module_class(ostream &out, Object *obj) {
   // allocfunc tp_alloc;
   out << "    PyType_GenericAlloc,\n";
   // newfunc tp_new;
-  out << "    Dtool_new_" << ClassName << ",\n";
+  write_function_slot(out, 4, slots, "tp_new", "Dtool_new_" + ClassName);
   // freefunc tp_free;
   if (obj->_protocol_types & Object::PT_python_gc) {
     out << "    PyObject_GC_Del,\n";

+ 1 - 0
dtool/src/interrogate/interfaceMakerPythonNative.h

@@ -82,6 +82,7 @@ private:
     WT_traverse,
     WT_compare,
     WT_hash,
+    WT_new,
   };
 
   // This enum is passed to the wrapper generation functions to indicate what

+ 1 - 1
dtool/src/interrogatedb/dtool_super_base.cxx

@@ -43,7 +43,7 @@ static void *Dtool_UpcastInterface_DTOOL_SUPER_BASE(PyObject *self, Dtool_PyType
   return nullptr;
 }
 
-static Dtool_PyInstDef *Dtool_Wrap_DTOOL_SUPER_BASE(void *from_this, Dtool_PyTypedObject *from_type) {
+static PyObject *Dtool_Wrap_DTOOL_SUPER_BASE(void *from_this, PyTypeObject *from_type) {
   return nullptr;
 }
 

+ 20 - 12
dtool/src/interrogatedb/py_panda.I

@@ -122,30 +122,38 @@ INLINE long Dtool_EnumValue_AsLong(PyObject *value) {
  */
 template<class T> INLINE PyObject *
 DTool_CreatePyInstance(const T *obj, bool memory_rules) {
-  Dtool_PyTypedObject *known_class = (Dtool_PyTypedObject *)get_type_handle(T).get_python_type();
-  nassertr(known_class != nullptr, nullptr);
-  return DTool_CreatePyInstance((void*) obj, *known_class, memory_rules, true);
+  Dtool_PyInstDef *self = (Dtool_PyInstDef *)get_type_handle(T).wrap_python(obj);
+  nassertr(self != nullptr, nullptr);
+  self->_memory_rules = memory_rules;
+  self->_is_const = true;
+  return (PyObject *)self;
 }
 
 template<class T> INLINE PyObject *
 DTool_CreatePyInstance(T *obj, bool memory_rules) {
-  Dtool_PyTypedObject *known_class = (Dtool_PyTypedObject *)get_type_handle(T).get_python_type();
-  nassertr(known_class != nullptr, nullptr);
-  return DTool_CreatePyInstance((void*) obj, *known_class, memory_rules, false);
+  Dtool_PyInstDef *self = (Dtool_PyInstDef *)get_type_handle(T).wrap_python(obj);
+  nassertr(self != nullptr, nullptr);
+  self->_memory_rules = memory_rules;
+  self->_is_const = false;
+  return (PyObject *)self;
 }
 
 template<class T> INLINE PyObject *
 DTool_CreatePyInstanceTyped(const T *obj, bool memory_rules) {
-  Dtool_PyTypedObject *known_class = (Dtool_PyTypedObject *)get_type_handle(T).get_python_type();
-  nassertr(known_class != nullptr, nullptr);
-  return DTool_CreatePyInstanceTyped((void*) obj, *known_class, memory_rules, true, obj->get_type().get_index());
+  Dtool_PyInstDef *self = (Dtool_PyInstDef *)get_type_handle(T).wrap_python(obj);
+  nassertr(self != nullptr, nullptr);
+  self->_memory_rules = memory_rules;
+  self->_is_const = true;
+  return (PyObject *)self;
 }
 
 template<class T> INLINE PyObject *
 DTool_CreatePyInstanceTyped(T *obj, bool memory_rules) {
-  Dtool_PyTypedObject *known_class = (Dtool_PyTypedObject *)get_type_handle(T).get_python_type();
-  nassertr(known_class != nullptr, nullptr);
-  return DTool_CreatePyInstanceTyped((void*) obj, *known_class, memory_rules, false, obj->get_type().get_index());
+  Dtool_PyInstDef *self = (Dtool_PyInstDef *)get_type_handle(T).wrap_python(obj);
+  nassertr(self != nullptr, nullptr);
+  self->_memory_rules = memory_rules;
+  self->_is_const = false;
+  return (PyObject *)self;
 }
 
 /**

+ 5 - 9
dtool/src/interrogatedb/py_panda.cxx

@@ -454,15 +454,11 @@ PyObject *DTool_CreatePyInstanceTyped(void *local_this_in, Dtool_PyTypedObject &
   // IF the class is possibly a run time typed object
   if (type_index > 0) {
     // get best fit class...
-    Dtool_PyTypedObject *target_class = (Dtool_PyTypedObject *)TypeHandle::from_index(type_index).get_python_type();
-    if (target_class != nullptr) {
-      // cast to the type...
-      Dtool_PyInstDef *self = target_class->_Dtool_WrapInterface(local_this_in, &known_class_type);
-      if (self != nullptr) {
-        self->_memory_rules = memory_rules;
-        self->_is_const = is_const;
-        return (PyObject *)self;
-      }
+    Dtool_PyInstDef *self = (Dtool_PyInstDef *)TypeHandle::from_index(type_index).wrap_python(local_this_in, &known_class_type._PyType);
+    if (self != nullptr) {
+      self->_memory_rules = memory_rules;
+      self->_is_const = is_const;
+      return (PyObject *)self;
     }
   }
 

+ 1 - 1
dtool/src/interrogatedb/py_panda.h

@@ -41,7 +41,7 @@ struct Dtool_PyTypedObject;
 // used to stamp dtool instance..
 #define PY_PANDA_SIGNATURE 0xbeaf
 typedef void *(*UpcastFunction)(PyObject *,Dtool_PyTypedObject *);
-typedef Dtool_PyInstDef *(*WrapFunction)(void *, Dtool_PyTypedObject *);
+typedef PyObject *(*WrapFunction)(void *, PyTypeObject *);
 typedef void *(*CoerceFunction)(PyObject *, void *);
 typedef void (*ModuleClassInitFunction)(PyObject *module);
 

+ 4 - 0
panda/src/pgraph/bamFile.h

@@ -47,7 +47,11 @@ PUBLISHED:
   bool open_read(std::istream &in, const std::string &bam_filename = "stream",
                  bool report_errors = true);
 
+#if defined(CPPPARSER) && defined(HAVE_PYTHON)
+  EXTENSION(PyObject *read_object());
+#else
   TypedWritable *read_object();
+#endif
 
   bool is_eof() const;
   bool resolve();

+ 19 - 1
panda/src/pgraph/bamFile_ext.cxx

@@ -12,10 +12,28 @@
  */
 
 #include "bamFile_ext.h"
-#include "bamWriter_ext.h"
+#include "bamReader_ext.h"
 
 #ifdef HAVE_PYTHON
 
+#ifndef CPPPARSER
+extern Dtool_PyTypedObject Dtool_TypedWritable;
+#endif  // CPPPARSER
+
+/**
+ * Reads an object from the BamFile.
+ */
+PyObject *Extension<BamFile>::
+read_object() {
+  BamReader *reader = _this->get_reader();
+  if (reader == nullptr) {
+    PyErr_SetString(PyExc_ValueError, "BamFile not open for reading");
+    return nullptr;
+  }
+
+  return invoke_extension(reader).read_object();
+}
+
 /**
  * Returns the version number of the Bam file currently being written.
  */

+ 2 - 0
panda/src/pgraph/bamFile_ext.h

@@ -29,6 +29,8 @@
 template<>
 class Extension<BamFile> : public ExtensionBase<BamFile> {
 public:
+  PyObject *read_object();
+
   PyObject *get_file_version() const;
 };
 

+ 4 - 0
panda/src/putil/bamReader.h

@@ -132,8 +132,12 @@ PUBLISHED:
   INLINE const LoaderOptions &get_loader_options() const;
   INLINE void set_loader_options(const LoaderOptions &options);
 
+#if defined(CPPPARSER) && defined(HAVE_PYTHON)
+  EXTENSION(PyObject *read_object());
+#else
   BLOCKING TypedWritable *read_object();
   BLOCKING bool read_object(TypedWritable *&ptr, ReferenceCount *&ref_ptr);
+#endif
 
   INLINE bool is_eof() const;
   bool resolve();

+ 67 - 11
panda/src/putil/bamReader_ext.cxx

@@ -53,24 +53,20 @@ static TypedWritable *factory_callback(const FactoryParams &params){
 
   if (result == nullptr) {
     util_cat.error()
-      << "Exception occurred in Python factory function\n";
-
-  } else if (result == Py_None) {
+      << "Exception occurred in Python factory function:\n";
+    PyErr_Print();
+  }
+  else if (result == Py_None) {
     util_cat.error()
       << "Python factory function returned None\n";
     Py_DECREF(result);
     result = nullptr;
   }
 
-#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
-  PyGILState_Release(gstate);
-#endif
+  void *object = nullptr;
 
   // Unwrap the returned TypedWritable object.
-  if (result == nullptr) {
-    return nullptr;
-  } else {
-    void *object = nullptr;
+  if (result != nullptr) {
     Dtool_Call_ExtractThisPointer(result, Dtool_TypedWritable, &object);
 
     TypedWritable *ptr = (TypedWritable *)object;
@@ -88,9 +84,69 @@ static TypedWritable *factory_callback(const FactoryParams &params){
       }
       Py_DECREF(result);
     }
+    else if (DtoolInstance_TYPE(result) == &Dtool_TypedWritable &&
+             Py_TYPE(result) != &Dtool_TypedWritable._PyType) {
+      // It is a custom subclass of TypedWritable, so we have to keep it
+      // alive, and decrement it in finalize(), see typedWritable_ext.cxx.
+      manager->register_finalize(ptr);
+    }
+    else {
+      // Otherwise, we just decrement the Python reference count, but making
+      // sure that the C++ object is not getting deleted (yet) by this.
+      bool mem_rules = false;
+      std::swap(mem_rules, ((Dtool_PyInstDef *)result)->_memory_rules);
+      Py_DECREF(result);
+      std::swap(mem_rules, ((Dtool_PyInstDef *)result)->_memory_rules);
+    }
+  }
 
-    return (TypedWritable *)object;
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+  PyGILState_Release(gstate);
+#endif
+
+  return (TypedWritable *)object;
+}
+
+/**
+ * Reads an object from the BamReader.
+ */
+PyObject *Extension<BamReader>::
+read_object() {
+  TypedWritable *ptr;
+  ReferenceCount *ref_ptr;
+
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+  PyThreadState *_save;
+  Py_UNBLOCK_THREADS
+#endif
+
+  bool success = _this->read_object(ptr, ref_ptr);
+
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+  Py_BLOCK_THREADS
+#endif
+
+  if (!success) {
+    if (_this->is_eof()) {
+      PyErr_SetNone(PyExc_EOFError);
+      return nullptr;
+    }
+    return nullptr;
+  }
+
+  if (ptr == nullptr) {
+    Py_INCREF(Py_None);
+    return Py_None;
+  }
+
+  if (ref_ptr != nullptr) {
+    ref_ptr->ref();
   }
+
+  // Note that, unlike the regular bindings, we take ownership of the object
+  // here even if it's not inheriting from ReferenceCount.
+  return DTool_CreatePyInstanceTyped((void *)ptr, Dtool_TypedWritable,
+                                     true, false, ptr->get_type_index());
 }
 
 /**

+ 2 - 0
panda/src/putil/bamReader_ext.h

@@ -29,6 +29,8 @@
 template<>
 class Extension<BamReader> : public ExtensionBase<BamReader> {
 public:
+  PyObject *read_object();
+
   PyObject *get_file_version() const;
 
   static void register_factory(TypeHandle handle, PyObject *func);

+ 4 - 1
panda/src/putil/typedWritable.h

@@ -33,6 +33,9 @@ class ReferenceCount;
  * See also TypedObject for detailed instructions.
  */
 class EXPCL_PANDA_PUTIL TypedWritable : public TypedObject {
+PUBLISHED:
+  EXTENSION(static PyObject *__new__(PyTypeObject *cls));
+
 public:
   static TypedWritable* const Null;
 
@@ -42,13 +45,13 @@ public:
 
   virtual ~TypedWritable();
 
-  virtual void write_datagram(BamWriter *manager, Datagram &dg);
   virtual void update_bam_nested(BamWriter *manager);
 
   virtual int complete_pointers(TypedWritable **p_list, BamReader *manager);
   virtual bool require_fully_complete() const;
 
 PUBLISHED:
+  virtual void write_datagram(BamWriter *manager, Datagram &dg);
   virtual void fillin(DatagramIterator &scan, BamReader *manager);
 
 public:

+ 205 - 0
panda/src/putil/typedWritable_ext.cxx

@@ -16,11 +16,216 @@
 #ifdef HAVE_PYTHON
 
 #include "bamWriter.h"
+#include "config_putil.h"
 
 #ifndef CPPPARSER
+extern Dtool_PyTypedObject Dtool_BamReader;
 extern Dtool_PyTypedObject Dtool_BamWriter;
+extern Dtool_PyTypedObject Dtool_Datagram;
+extern Dtool_PyTypedObject Dtool_DatagramIterator;
+extern Dtool_PyTypedObject Dtool_TypedObject;
+extern Dtool_PyTypedObject Dtool_TypedWritable;
+extern Dtool_PyTypedObject Dtool_TypeHandle;
 #endif  // CPPPARSER
 
+/**
+ * Class that upcalls to the parent class when write_datagram is called.
+ */
+class TypedWritableProxy : public TypedWritable {
+public:
+  ~TypedWritableProxy() {
+  }
+
+  virtual void write_datagram(BamWriter *manager, Datagram &dg) override {
+    // The derived method may call back to the TypedWritable implementation,
+    // which would end up back here.  Detect and prevent this.
+    thread_local TypedWritableProxy *recursion_protect = nullptr;
+    if (recursion_protect == this) {
+      TypedWritable::write_datagram(manager, dg);
+      return;
+    }
+
+    // We don't know where this might be invoked, so we have to be on the safe
+    // side and ensure that the GIL is being held.
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+    PyGILState_STATE gstate = PyGILState_Ensure();
+#endif
+
+    TypedWritableProxy *prev_recursion_protect = this;
+    std::swap(recursion_protect, prev_recursion_protect);
+
+    PyObject *py_manager = DTool_CreatePyInstance(manager, Dtool_BamWriter, false, false);
+    PyObject *py_dg = DTool_CreatePyInstance(&dg, Dtool_Datagram, false, false);
+
+    PyObject *result = PyObject_CallMethod(_self, "write_datagram", "NN", py_manager, py_dg);
+    if (result != nullptr) {
+      Py_DECREF(result);
+    } else {
+      util_cat.error()
+        << "Exception occurred in Python write_datagram function:\n";
+      PyErr_Print();
+    }
+
+    std::swap(recursion_protect, prev_recursion_protect);
+
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+    PyGILState_Release(gstate);
+#endif
+  }
+
+  virtual void fillin(DatagramIterator &scan, BamReader *manager) override {
+    // The derived method may call back to the TypedWritable implementation,
+    // which would end up back here.  Detect and prevent this.
+    thread_local TypedWritableProxy *recursion_protect = nullptr;
+    if (recursion_protect == this) {
+      TypedWritable::fillin(scan, manager);
+      return;
+    }
+
+    // We don't know where this might be invoked, so we have to be on the safe
+    // side and ensure that the GIL is being held.
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+    PyGILState_STATE gstate = PyGILState_Ensure();
+#endif
+
+    TypedWritableProxy *prev_recursion_protect = this;
+    std::swap(recursion_protect, prev_recursion_protect);
+
+    PyObject *py_scan = DTool_CreatePyInstance(&scan, Dtool_DatagramIterator, false, false);
+    PyObject *py_manager = DTool_CreatePyInstance(manager, Dtool_BamReader, false, false);
+
+    PyObject *result = PyObject_CallMethod(_self, "fillin", "NN", py_scan, py_manager);
+    if (result != nullptr) {
+      Py_DECREF(result);
+    } else {
+      util_cat.error()
+        << "Exception occurred in Python fillin function:\n";
+      PyErr_Print();
+    }
+
+    std::swap(recursion_protect, prev_recursion_protect);
+
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+    PyGILState_Release(gstate);
+#endif
+  }
+
+  virtual void finalize(BamReader *manager) override {
+    // If nobody stored the object after calling read_object(), this will cause
+    // this object to be deleted.
+    Py_DECREF(_self);
+  }
+
+  virtual TypeHandle get_type() const override {
+    return _type;
+  }
+
+  virtual TypeHandle force_init_type() override {
+    return _type;
+  }
+
+public:
+  PyObject *_self;
+  TypeHandle _type;
+};
+
+/**
+ * Returns a wrapper object for a TypedWritable subclass.
+ */
+static PyObject *
+wrap_typed_writable(void *from_this, PyTypeObject *from_type) {
+  nassertr(from_this != nullptr, nullptr);
+  nassertr(from_type != nullptr, nullptr);
+
+  TypedWritableProxy *to_this;
+  if (from_type == &Dtool_TypedWritable._PyType) {
+    to_this = (TypedWritableProxy *)(TypedWritable *)from_this;
+  }
+  else if (from_type == (PyTypeObject *)&Dtool_TypedObject._PyType) {
+    to_this = (TypedWritableProxy *)(TypedObject *)from_this;
+  }
+  else {
+    return nullptr;
+  }
+
+  nassertr(to_this->_self != nullptr, nullptr);
+  Py_INCREF(to_this->_self);
+  return to_this->_self;
+}
+
+/**
+ * Registers a Python type recursively, towards the TypedWritable base.
+ * Returns a TypeHandle if it inherited from TypedWritable, 0 otherwise.
+ */
+static TypeHandle
+register_python_type(TypeRegistry *registry, PyTypeObject *cls) {
+  TypeHandle handle = TypeHandle::none();
+
+  if (cls->tp_bases != nullptr) {
+    Py_ssize_t count = PyTuple_GET_SIZE(cls->tp_bases);
+    for (Py_ssize_t i = 0; i < count; ++i) {
+      PyObject *base = PyTuple_GET_ITEM(cls->tp_bases, count);
+      TypeHandle base_handle = register_python_type(registry, (PyTypeObject *)base);
+      if (base_handle != TypeHandle::none()) {
+        if (handle == TypeHandle::none()) {
+          handle = registry->register_dynamic_type(cls->tp_name);
+        }
+        registry->record_derivation(handle, base_handle);
+        return handle;
+      }
+    }
+  }
+
+  return handle;
+}
+
+/**
+ * This is called when a TypedWritable is instantiated directly.
+ */
+PyObject *Extension<TypedWritable>::
+__new__(PyTypeObject *cls) {
+  if (cls == (PyTypeObject *)&Dtool_TypedWritable) {
+    return Dtool_Raise_TypeError("cannot init abstract class");
+  }
+
+  PyObject *self = cls->tp_alloc(cls, 0);
+  ((Dtool_PyInstDef *)self)->_signature = PY_PANDA_SIGNATURE;
+  ((Dtool_PyInstDef *)self)->_My_Type = &Dtool_TypedWritable;
+
+  // We expect the user to override this method.
+  PyObject *class_type = PyObject_CallMethod((PyObject *)cls, "get_class_type", nullptr);
+  if (class_type == nullptr) {
+    return nullptr;
+  }
+
+  // Check that it returned a TypeHandle, and that it is actually different
+  // from the one on the base class (which might mean that the user didn't
+  // actually define a custom get_class_type() method).
+  TypeHandle *handle = nullptr;
+  if (!DtoolInstance_GetPointer(class_type, handle, Dtool_TypeHandle) ||
+      *handle == TypedWritable::get_class_type() ||
+      *handle == TypedObject::get_class_type() ||
+      !handle->is_derived_from(TypedWritable::get_class_type())) {
+    Dtool_Raise_TypeError("get_class_type() must be overridden to return a unique TypeHandle that indicates derivation from TypedWritable");
+    return nullptr;
+  }
+
+  // Make sure that the bindings know how to obtain a wrapper for this type.
+  TypeRegistry *registry = TypeRegistry::ptr();
+  registry->record_python_type(*handle, cls, &wrap_typed_writable);
+  Py_INCREF(cls);
+
+  // Note that we don't increment the reference count here, because that would
+  // create a memory leak.  The TypedWritableProxy gets deleted when the Python
+  // object reaches a reference count of 0.
+  TypedWritableProxy *proxy = new TypedWritableProxy;
+  proxy->_self = self;
+  proxy->_type = *handle;
+
+  DTool_PyInit_Finalize(self, (void *)proxy, &Dtool_TypedWritable, true, false);
+  return self;
+}
+
 /**
  * This special Python method is implement to provide support for the pickle
  * module.

+ 2 - 0
panda/src/putil/typedWritable_ext.h

@@ -29,6 +29,8 @@
 template<>
 class Extension<TypedWritable> : public ExtensionBase<TypedWritable> {
 public:
+  static PyObject *__new__(PyTypeObject *cls);
+
   PyObject *__reduce__(PyObject *self) const;
   PyObject *__reduce_persist__(PyObject *self, PyObject *pickler) const;
 

+ 53 - 0
tests/putil/test_custom_writable.py

@@ -0,0 +1,53 @@
+from panda3d.core import TypeRegistry, TypedWritable
+from panda3d.core import DatagramBuffer, BamReader, BamWriter
+import sys
+
+
+class CustomObject(TypedWritable):
+    def __init__(self):
+        self.field = 0
+
+    def get_class_type():
+        registry = TypeRegistry.ptr()
+        handle = registry.register_dynamic_type("CustomObject")
+        registry.record_derivation(handle, TypedWritable)
+        return handle
+
+    def write_datagram(self, writer, dg):
+        dg.add_uint8(self.field)
+
+    def fillin(self, scan, reader):
+        self.field = scan.get_uint8()
+
+    @staticmethod
+    def make_from_bam(scan, reader):
+        obj = CustomObject()
+        obj.fillin(scan, reader)
+        return obj
+
+
+BamReader.register_factory(CustomObject.get_class_type(), CustomObject.make_from_bam)
+
+
+def test_typed_writable_subclass():
+    obj = CustomObject()
+    obj.field = 123
+    assert obj.get_type() == CustomObject.get_class_type()
+    assert obj.type == CustomObject.get_class_type()
+
+    buf = DatagramBuffer()
+
+    writer = BamWriter(buf)
+    writer.init()
+    writer.write_object(obj)
+    del writer
+
+    reader = BamReader(buf)
+    reader.init()
+    obj = reader.read_object()
+    assert sys.getrefcount(obj) == 3
+    reader.resolve()
+    del reader
+    assert sys.getrefcount(obj) == 2
+
+    assert obj.field == 123