Jelajahi Sumber

Make use of PyImport_GetModule function

rdb 1 tahun lalu
induk
melakukan
4b2fa45578

+ 14 - 0
dtool/src/interrogatedb/py_compat.h

@@ -189,6 +189,20 @@ INLINE PyObject *_PyObject_FastCall(PyObject *func, PyObject **args, Py_ssize_t
   } while (0)
 #endif
 
+#if PY_VERSION_HEX < 0x03070000
+INLINE PyObject *PyImport_GetModule(PyObject *name) {
+  PyObject *modules = PyImport_GetModuleDict();
+  if (modules != nullptr) {
+    PyObject *module = PyDict_GetItem(modules, name);
+    if (module != nullptr) {
+      Py_INCREF(module);
+      return module;
+    }
+  }
+  return nullptr;
+}
+#endif
+
 /* Python 3.8 */
 #if PY_VERSION_HEX < 0x03080000
 INLINE PyObject *_PyLong_Rshift(PyObject *a, size_t shiftby) {

+ 22 - 16
dtool/src/interrogatedb/py_wrappers.cxx

@@ -17,27 +17,33 @@
 #endif
 
 static void _register_collection(PyTypeObject *type, const char *abc) {
-  PyObject *sys_modules = PyImport_GetModuleDict();
-  if (sys_modules != nullptr) {
-    PyObject *module = PyDict_GetItemString(sys_modules, _COLLECTIONS_ABC);
-    if (module != nullptr) {
-      PyObject *dict = PyModule_GetDict(module);
-      if (module != nullptr) {
 #if PY_MAJOR_VERSION >= 3
-        PyObject *register_str = PyUnicode_InternFromString("register");
+  PyObject *module_name = PyUnicode_InternFromString(_COLLECTIONS_ABC);
 #else
-        PyObject *register_str = PyString_InternFromString("register");
+  PyObject *module_name = PyString_InternFromString(_COLLECTIONS_ABC);
 #endif
-        PyObject *obj = nullptr;
-        if (register_str == nullptr ||
-            PyDict_GetItemStringRef(dict, abc, &obj) <= 0 ||
-            PyObject_CallMethodOneArg(obj, register_str, (PyObject *)type) == nullptr) {
-          PyErr_Print();
-        }
-        Py_XDECREF(obj);
-        Py_XDECREF(register_str);
+  PyObject *module = PyImport_GetModule(module_name);
+  Py_DECREF(module_name);
+  if (module != nullptr) {
+    PyObject *dict = PyModule_GetDict(module);
+    if (dict != nullptr) {
+#if PY_MAJOR_VERSION >= 3
+      PyObject *register_str = PyUnicode_InternFromString("register");
+#else
+      PyObject *register_str = PyString_InternFromString("register");
+#endif
+      PyObject *obj = nullptr;
+      if (register_str == nullptr ||
+          PyDict_GetItemStringRef(dict, abc, &obj) <= 0 ||
+          PyObject_CallMethodOneArg(obj, register_str, (PyObject *)type) == nullptr) {
+        PyErr_Print();
       }
+      Py_XDECREF(obj);
+      Py_XDECREF(register_str);
+    } else {
+      PyErr_Clear();
     }
+    Py_DECREF(module);
   }
 }
 

+ 2 - 8
panda/src/egg/eggNode_ext.cxx

@@ -24,18 +24,12 @@ __reduce__() const {
   extern struct Dtool_PyTypedObject Dtool_EggNode;
 
   // Find the parse_egg_node function in this module.
-  PyObject *sys_modules = PyImport_GetModuleDict();
-  nassertr_always(sys_modules != nullptr, nullptr);
-
   PyObject *module_name = PyObject_GetAttrString((PyObject *)&Dtool_EggNode, "__module__");
   nassertr_always(module_name != nullptr, nullptr);
 
-  PyObject *module;
-  int res = PyDict_GetItemRef(sys_modules, module_name, &module);
+  PyObject *module = PyImport_GetModule(module_name);
   Py_DECREF(module_name);
-  if (res <= 0) {
-    return nullptr;
-  }
+  nassertr_always(module != nullptr, nullptr);
 
   PyObject *func;
   if (_this->is_of_type(EggData::get_class_type())) {

+ 7 - 12
panda/src/putil/typedWritable_ext.cxx

@@ -336,20 +336,15 @@ find_global_decode(PyObject *this_class, const char *func_name) {
   // Get the module in which BamWriter is defined.
   PyObject *module_name = PyObject_GetAttrString((PyObject *)&Dtool_BamWriter, "__module__");
   if (module_name != nullptr) {
-    // borrowed reference
-    PyObject *sys_modules = PyImport_GetModuleDict();
-    if (sys_modules != nullptr) {
-      // borrowed reference
-      PyObject *module = PyDict_GetItem(sys_modules, module_name);
-      if (module != nullptr) {
-        PyObject *func = PyObject_GetAttrString(module, (char *)func_name);
-        if (func != nullptr) {
-          Py_DECREF(module_name);
-          return func;
-        }
+    PyObject *module = PyImport_GetModule(module_name);
+    Py_DECREF(module_name);
+    if (module != nullptr) {
+      PyObject *func = PyObject_GetAttrString(module, (char *)func_name);
+      Py_DECREF(module);
+      if (func != nullptr) {
+        return func;
       }
     }
-    Py_DECREF(module_name);
   }
 
   PyObject *bases = PyObject_GetAttrString(this_class, "__bases__");