Ver código fonte

event: Allow AsyncFuture to store arbitrary PyObject result

rdb 4 anos atrás
pai
commit
064e0383be

+ 3 - 2
panda/src/event/asyncFuture.h

@@ -85,13 +85,14 @@ PUBLISHED:
   BLOCKING void wait();
   BLOCKING void wait();
   BLOCKING void wait(double timeout);
   BLOCKING void wait(double timeout);
 
 
+  EXTENSION(void set_result(PyObject *));
+
+public:
   INLINE void set_result(std::nullptr_t);
   INLINE void set_result(std::nullptr_t);
   INLINE void set_result(TypedObject *result);
   INLINE void set_result(TypedObject *result);
   INLINE void set_result(TypedReferenceCount *result);
   INLINE void set_result(TypedReferenceCount *result);
   INLINE void set_result(TypedWritableReferenceCount *result);
   INLINE void set_result(TypedWritableReferenceCount *result);
   INLINE void set_result(const EventParameter &result);
   INLINE void set_result(const EventParameter &result);
-
-public:
   void set_result(TypedObject *ptr, ReferenceCount *ref_ptr);
   void set_result(TypedObject *ptr, ReferenceCount *ref_ptr);
 
 
   INLINE TypedObject *get_result() const;
   INLINE TypedObject *get_result() const;

+ 79 - 1
panda/src/event/asyncFuture_ext.cxx

@@ -15,6 +15,7 @@
 #include "asyncTaskSequence.h"
 #include "asyncTaskSequence.h"
 #include "eventParameter.h"
 #include "eventParameter.h"
 #include "paramValue.h"
 #include "paramValue.h"
+#include "paramPyObject.h"
 #include "pythonTask.h"
 #include "pythonTask.h"
 #include "asyncTaskManager.h"
 #include "asyncTaskManager.h"
 #include "config_event.h"
 #include "config_event.h"
@@ -23,8 +24,11 @@
 
 
 #ifndef CPPPARSER
 #ifndef CPPPARSER
 extern struct Dtool_PyTypedObject Dtool_AsyncFuture;
 extern struct Dtool_PyTypedObject Dtool_AsyncFuture;
+extern struct Dtool_PyTypedObject Dtool_EventParameter;
 extern struct Dtool_PyTypedObject Dtool_ParamValueBase;
 extern struct Dtool_PyTypedObject Dtool_ParamValueBase;
 extern struct Dtool_PyTypedObject Dtool_TypedObject;
 extern struct Dtool_PyTypedObject Dtool_TypedObject;
+extern struct Dtool_PyTypedObject Dtool_TypedReferenceCount;
+extern struct Dtool_PyTypedObject Dtool_TypedWritableReferenceCount;
 #endif
 #endif
 
 
 /**
 /**
@@ -92,9 +96,13 @@ static PyObject *get_done_result(const AsyncFuture *future) {
         // EventStoreInt and Double are not exposed to Python for some reason.
         // EventStoreInt and Double are not exposed to Python for some reason.
         if (type == EventStoreInt::get_class_type()) {
         if (type == EventStoreInt::get_class_type()) {
           return Dtool_WrapValue(((EventStoreInt *)ptr)->get_value());
           return Dtool_WrapValue(((EventStoreInt *)ptr)->get_value());
-        } else if (type == EventStoreDouble::get_class_type()) {
+        }
+        else if (type == EventStoreDouble::get_class_type()) {
           return Dtool_WrapValue(((EventStoreDouble *)ptr)->get_value());
           return Dtool_WrapValue(((EventStoreDouble *)ptr)->get_value());
         }
         }
+        else if (type == ParamPyObject::get_class_type()) {
+          return ((ParamPyObject *)ptr)->get_value();
+        }
 
 
         ParamValueBase *value = (ParamValueBase *)ptr;
         ParamValueBase *value = (ParamValueBase *)ptr;
         PyObject *wrap = DTool_CreatePyInstanceTyped
         PyObject *wrap = DTool_CreatePyInstanceTyped
@@ -176,6 +184,76 @@ __await__(PyObject *self) {
   return Dtool_NewGenerator(self, &gen_next);
   return Dtool_NewGenerator(self, &gen_next);
 }
 }
 
 
+/**
+ * Sets this future's result.  Can only be called if done() returns false.
+ */
+void Extension<AsyncFuture>::
+set_result(PyObject *result) {
+  if (result == Py_None) {
+    _this->set_result(nullptr);
+    return;
+  }
+  else if (DtoolInstance_Check(result)) {
+    void *ptr;
+    if ((ptr = DtoolInstance_UPCAST(result, Dtool_EventParameter))) {
+      _this->set_result(*(const EventParameter *)ptr);
+      return;
+    }
+    if ((ptr = DtoolInstance_UPCAST(result, Dtool_TypedWritableReferenceCount))) {
+      _this->set_result((TypedWritableReferenceCount *)ptr);
+      return;
+    }
+    if ((ptr = DtoolInstance_UPCAST(result, Dtool_TypedReferenceCount))) {
+      _this->set_result((TypedReferenceCount *)ptr);
+      return;
+    }
+    if ((ptr = DtoolInstance_UPCAST(result, Dtool_TypedObject))) {
+      _this->set_result((TypedObject *)ptr);
+      return;
+    }
+  }
+  else if (PyUnicode_Check(result)) {
+#if PY_VERSION_HEX >= 0x03030000
+    Py_ssize_t result_len;
+    wchar_t *result_str = PyUnicode_AsWideCharString(result, &result_len);
+#else
+    Py_ssize_t result_len = PyUnicode_GET_SIZE(result);
+    wchar_t *result_str = (wchar_t *)alloca(sizeof(wchar_t) * (result_len + 1));
+    PyUnicode_AsWideChar((PyUnicodeObject *)result, result_str, result_len);
+#endif
+    _this->set_result(new EventStoreWstring(std::wstring(result_str, result_len)));
+#if PY_VERSION_HEX >= 0x03030000
+    PyMem_Free(result_str);
+#endif
+    return;
+  }
+#if PY_MAJOR_VERSION < 3
+  else if (PyString_Check(result)) {
+    const char *result_str;
+    Py_ssize_t result_len;
+    if (PyString_AsStringAndSize(result, (char **)&result_str, &result_len) != -1) {
+      _this->set_result(new EventStoreString(std::string(result_str, result_len)));
+    }
+    return;
+  }
+#endif
+  else if (PyLongOrInt_Check(result)) {
+    long result_val = PyLongOrInt_AS_LONG(result);
+    if (result_val >= INT_MIN && result_val <= INT_MAX) {
+      _this->set_result(new EventStoreInt((int)result_val));
+      return;
+    }
+  }
+  else if (PyNumber_Check(result)) {
+    _this->set_result(new EventStoreDouble(PyFloat_AsDouble(result)));
+    return;
+  }
+
+  // If we don't recognize the type, store it as a generic PyObject pointer.
+  ParamPyObject::init_type();
+  _this->set_result(new ParamPyObject(result));
+}
+
 /**
 /**
  * Returns the result of this future, unless it was cancelled, in which case
  * Returns the result of this future, unless it was cancelled, in which case
  * it returns CancelledError.
  * it returns CancelledError.

+ 1 - 0
panda/src/event/asyncFuture_ext.h

@@ -29,6 +29,7 @@ public:
   static PyObject *__await__(PyObject *self);
   static PyObject *__await__(PyObject *self);
   static PyObject *__iter__(PyObject *self) { return __await__(self); }
   static PyObject *__iter__(PyObject *self) { return __await__(self); }
 
 
+  void set_result(PyObject *result);
   PyObject *result(PyObject *timeout = Py_None) const;
   PyObject *result(PyObject *timeout = Py_None) const;
 
 
   PyObject *add_done_callback(PyObject *self, PyObject *fn);
   PyObject *add_done_callback(PyObject *self, PyObject *fn);

+ 1 - 0
panda/src/putil/p3putil_ext_composite.cxx

@@ -1,5 +1,6 @@
 #include "bamReader_ext.cxx"
 #include "bamReader_ext.cxx"
 #include "bitArray_ext.cxx"
 #include "bitArray_ext.cxx"
+#include "paramPyObject.cxx"
 #include "pythonCallbackObject.cxx"
 #include "pythonCallbackObject.cxx"
 #include "sparseArray_ext.cxx"
 #include "sparseArray_ext.cxx"
 #include "typedWritable_ext.cxx"
 #include "typedWritable_ext.cxx"

+ 31 - 0
panda/src/putil/paramPyObject.I

@@ -0,0 +1,31 @@
+/**
+ * PANDA 3D SOFTWARE
+ * Copyright (c) Carnegie Mellon University.  All rights reserved.
+ *
+ * All use of this software is subject to the terms of the revised BSD
+ * license.  You should have received a copy of this license along
+ * with this source code in a file named "LICENSE."
+ *
+ * @file paramPyObject.I
+ * @author rdb
+ * @date 2021-03-01
+ */
+
+#include "paramPyObject.h"
+
+/**
+ * Increments the reference count.  Assumes the GIL is held.
+ */
+INLINE ParamPyObject::
+ParamPyObject(PyObject *value) : _value(value) {
+  Py_INCREF(value);
+}
+
+/**
+ * Returns a new reference to the stored value.
+ */
+INLINE PyObject *ParamPyObject::
+get_value() const {
+  Py_INCREF(_value);
+  return _value;
+}

+ 42 - 0
panda/src/putil/paramPyObject.cxx

@@ -0,0 +1,42 @@
+/**
+ * PANDA 3D SOFTWARE
+ * Copyright (c) Carnegie Mellon University.  All rights reserved.
+ *
+ * All use of this software is subject to the terms of the revised BSD
+ * license.  You should have received a copy of this license along
+ * with this source code in a file named "LICENSE."
+ *
+ * @file paramPyObject.cxx
+ * @author rdb
+ * @date 2021-03-01
+ */
+
+#include "paramPyObject.h"
+
+TypeHandle ParamPyObject::_type_handle;
+
+/**
+ * Decrements the reference count.
+ */
+ParamPyObject::
+~ParamPyObject() {
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+  PyGILState_STATE gstate;
+  gstate = PyGILState_Ensure();
+#endif
+
+  Py_DECREF(_value);
+
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+  PyGILState_Release(gstate);
+#endif
+}
+
+/**
+ *
+ */
+void ParamPyObject::
+output(std::ostream &out) const {
+  out << "<" << Py_TYPE(_value)->tp_name
+      << " object at " << (void *)_value << ">";
+}

+ 61 - 0
panda/src/putil/paramPyObject.h

@@ -0,0 +1,61 @@
+/**
+ * PANDA 3D SOFTWARE
+ * Copyright (c) Carnegie Mellon University.  All rights reserved.
+ *
+ * All use of this software is subject to the terms of the revised BSD
+ * license.  You should have received a copy of this license along
+ * with this source code in a file named "LICENSE."
+ *
+ * @file paramPyObject.h
+ * @author rdb
+ * @date 2021-03-01
+ */
+
+#ifndef PARAMPYOBJECT_H
+#define PARAMPYOBJECT_H
+
+#include "pandabase.h"
+#include "paramValue.h"
+
+#ifdef HAVE_PYTHON
+
+#include "py_panda.h"
+
+/**
+ * A class object for storing an arbitrary Python object.
+ */
+class ParamPyObject final : public ParamValueBase {
+public:
+  INLINE ParamPyObject(PyObject *value);
+  virtual ~ParamPyObject();
+
+  INLINE PyObject *get_value() const;
+
+  void output(std::ostream &out) const override;
+
+public:
+  PyObject *_value;
+
+public:
+  virtual TypeHandle get_type() const override {
+    return get_class_type();
+  }
+  virtual TypeHandle force_init_type() override {init_type(); return get_class_type();}
+  static TypeHandle get_class_type() {
+    return _type_handle;
+  }
+  static void init_type() {
+    ParamValueBase::init_type();
+    register_type(_type_handle, "ParamPyObject",
+                  ParamValueBase::get_class_type());
+  }
+
+private:
+  static TypeHandle _type_handle;
+};
+
+#include "paramPyObject.I"
+
+#endif  // HAVE_PYTHON
+
+#endif

+ 65 - 0
tests/event/test_futures.py

@@ -164,6 +164,71 @@ def test_coro_exception():
         task.result()
         task.result()
 
 
 
 
+def test_future_result():
+    # Cancelled
+    fut = core.AsyncFuture()
+    assert not fut.done()
+    fut.cancel()
+    with pytest.raises(Exception):
+        fut.result()
+
+    # None
+    fut = core.AsyncFuture()
+    fut.set_result(None)
+    assert fut.done()
+    assert fut.result() is None
+
+    # Store int
+    fut = core.AsyncFuture()
+    fut.set_result(123)
+    assert fut.result() == 123
+
+    # Store string
+    fut = core.AsyncFuture()
+    fut.set_result("test\000\u1234")
+    assert fut.result() == "test\000\u1234"
+
+    # Store TypedWritableReferenceCount
+    tex = core.Texture()
+    rc = tex.get_ref_count()
+    fut = core.AsyncFuture()
+    fut.set_result(tex)
+    assert tex.get_ref_count() == rc + 1
+    assert fut.result() == tex
+    assert tex.get_ref_count() == rc + 1
+    assert fut.result() == tex
+    assert tex.get_ref_count() == rc + 1
+    fut = None
+    assert tex.get_ref_count() == rc
+
+    # Store EventParameter (gets unwrapped)
+    ep = core.EventParameter(0.5)
+    fut = core.AsyncFuture()
+    fut.set_result(ep)
+    assert fut.result() == 0.5
+    assert fut.result() == 0.5
+
+    # Store TypedObject
+    dg = core.Datagram(b"test")
+    fut = core.AsyncFuture()
+    fut.set_result(dg)
+    assert fut.result() == dg
+    assert fut.result() == dg
+
+    # Store arbitrary Python object
+    obj = object()
+    rc = sys.getrefcount(obj)
+    fut = core.AsyncFuture()
+    fut.set_result(obj)
+    assert sys.getrefcount(obj) == rc + 1
+    assert fut.result() is obj
+    assert sys.getrefcount(obj) == rc + 1
+    assert fut.result() is obj
+    assert sys.getrefcount(obj) == rc + 1
+    fut = None
+    assert sys.getrefcount(obj) == rc
+
+
 def test_future_gather():
 def test_future_gather():
     fut1 = core.AsyncFuture()
     fut1 = core.AsyncFuture()
     fut2 = core.AsyncFuture()
     fut2 = core.AsyncFuture()