2
0
Эх сурвалжийг харах

task: Support calling cancel() on currently awaiting futures

Fixes #911
rdb 5 жил өмнө
parent
commit
bfbbcad990

+ 34 - 16
panda/src/event/asyncFuture_ext.cxx

@@ -117,22 +117,7 @@ static PyObject *get_done_result(const AsyncFuture *future) {
     }
   } else {
     // If the future was cancelled, we should raise an exception.
-    static PyObject *exc_type = nullptr;
-    if (exc_type == nullptr) {
-      // Get the CancelledError that asyncio uses, too.
-      PyObject *module = PyImport_ImportModule("concurrent.futures._base");
-      if (module != nullptr) {
-        exc_type = PyObject_GetAttrString(module, "CancelledError");
-        Py_DECREF(module);
-      }
-      // If we can't get that, we should pretend and make our own.
-      if (exc_type == nullptr) {
-        exc_type = PyErr_NewExceptionWithDoc((char*)"concurrent.futures._base.CancelledError",
-                                             (char*)"The Future was cancelled.",
-                                             nullptr, nullptr);
-      }
-    }
-    PyErr_SetNone(exc_type);
+    PyErr_SetNone(Extension<AsyncFuture>::get_cancelled_error_type());
     return nullptr;
   }
 }
@@ -303,4 +288,37 @@ gather(PyObject *args) {
   }
 }
 
+/**
+ * Returns a borrowed reference to the CancelledError exception type.
+ */
+PyObject *Extension<AsyncFuture>::
+get_cancelled_error_type() {
+  static PyObject *exc_type = nullptr;
+  if (exc_type == nullptr) {
+    // Get the CancelledError that asyncio uses, too.
+#if PY_VERSION_HEX >= 0x03080000
+    PyObject *module = PyImport_ImportModule("asyncio.exceptions");
+#else
+    PyObject *module = PyImport_ImportModule("concurrent.futures._base");
+#endif
+    if (module != nullptr) {
+      exc_type = PyObject_GetAttrString(module, "CancelledError");
+      Py_DECREF(module);
+    }
+    // If we can't get that, we should pretend and make our own.
+    if (exc_type == nullptr) {
+#if PY_VERSION_HEX >= 0x03080000
+      exc_type = PyErr_NewExceptionWithDoc((char *)"asyncio.exceptions.CancelledError",
+                                            (char *)"The Future or Task was cancelled.",
+                                            PyExc_BaseException, nullptr);
+#else
+      exc_type = PyErr_NewExceptionWithDoc((char *)"concurrent.futures._base.CancelledError",
+                                            (char *)"The Future was cancelled.",
+                                            nullptr, nullptr);
+#endif
+    }
+  }
+  return exc_type;
+}
+
 #endif

+ 2 - 0
panda/src/event/asyncFuture_ext.h

@@ -34,6 +34,8 @@ public:
   PyObject *add_done_callback(PyObject *self, PyObject *fn);
 
   static PyObject *gather(PyObject *args);
+
+  static PyObject *get_cancelled_error_type();
 };
 
 #endif  // HAVE_PYTHON

+ 5 - 1
panda/src/event/asyncTask.cxx

@@ -68,6 +68,9 @@ AsyncTask::
  * Removes the task from its active manager, if any, and makes the state
  * S_inactive (or possible S_servicing_removed).  This is a no-op if the state
  * is already S_inactive.
+ *
+ * If the task is a coroutine that is currently awaiting a future, this will
+ * fail, but see also cancel().
  */
 bool AsyncTask::
 remove() {
@@ -457,7 +460,8 @@ unlock_and_do_task() {
 }
 
 /**
- * Cancels this task.  This is equivalent to remove().
+ * Cancels this task.  This is equivalent to remove(), except for coroutines,
+ * for which it will throw an exception into any currently pending await.
  */
 bool AsyncTask::
 cancel() {

+ 1 - 1
panda/src/event/asyncTask.h

@@ -124,7 +124,7 @@ protected:
   void jump_to_task_chain(AsyncTaskManager *manager);
   DoneStatus unlock_and_do_task();
 
-  virtual bool cancel() final;
+  virtual bool cancel();
   virtual bool is_task() const final {return true;}
 
   virtual bool is_runnable();

+ 1 - 0
panda/src/event/asyncTaskChain.h

@@ -218,6 +218,7 @@ private:
   friend class AsyncTask;
   friend class AsyncTaskManager;
   friend class AsyncTaskSortWakeTime;
+  friend class PythonTask;
 };
 
 INLINE std::ostream &operator << (std::ostream &out, const AsyncTaskChain &chain) {

+ 80 - 6
panda/src/event/pythonTask.cxx

@@ -20,6 +20,7 @@
 
 #include "pythonThread.h"
 #include "asyncTaskManager.h"
+#include "asyncFuture_ext.h"
 
 TypeHandle PythonTask::_type_handle;
 
@@ -391,6 +392,51 @@ __clear__() {
   return 0;
 }
 
+/**
+ * Cancels this task.  This is equivalent to remove(), except for coroutines,
+ * for which it will throw an exception into any currently pending await.
+ */
+bool PythonTask::
+cancel() {
+  AsyncTaskManager *manager = _manager;
+  if (manager != nullptr) {
+    nassertr(_chain->_manager == manager, false);
+    if (task_cat.is_debug()) {
+      task_cat.debug()
+        << "Cancelling " << *this << "\n";
+    }
+
+    MutexHolder holder(manager->_lock);
+    if (_state == S_awaiting) {
+      // Reactivate it so that it can receive a CancelledException.
+      _must_cancel = true;
+      _state = AsyncTask::S_active;
+      _chain->_active.push_back(this);
+      --_chain->_num_awaiting_tasks;
+      return true;
+    }
+    else if (_future_done != nullptr) {
+      // We are polling, waiting for a non-Panda future to be done.
+      Py_DECREF(_future_done);
+      _future_done = nullptr;
+      _must_cancel = true;
+      return true;
+    }
+    else if (_chain->do_remove(this, true)) {
+      return true;
+    }
+    else {
+      if (task_cat.is_debug()) {
+        task_cat.debug()
+          << "  (unable to cancel " << *this << ")\n";
+      }
+      return false;
+    }
+  }
+
+  return false;
+}
+
 /**
  * Override this function to return true if the task can be successfully
  * executed, false if it cannot.  Mainly intended as a sanity check when
@@ -492,12 +538,22 @@ do_python_task() {
   }
 
   if (_generator != nullptr) {
-    // We are calling a generator.  Use "send" rather than PyIter_Next since
-    // we need to be able to read the value from a StopIteration exception.
-    PyObject *func = PyObject_GetAttrString(_generator, "send");
-    nassertr(func != nullptr, DS_interrupt);
-    result = PyObject_CallFunctionObjArgs(func, Py_None, nullptr);
-    Py_DECREF(func);
+    if (!_must_cancel) {
+      // We are calling a generator.  Use "send" rather than PyIter_Next since
+      // we need to be able to read the value from a StopIteration exception.
+      PyObject *func = PyObject_GetAttrString(_generator, "send");
+      nassertr(func != nullptr, DS_interrupt);
+      result = PyObject_CallFunctionObjArgs(func, Py_None, nullptr);
+      Py_DECREF(func);
+    } else {
+      // Throw a CancelledError into the generator.
+      _must_cancel = false;
+      PyObject *exc = _PyObject_CallNoArg(Extension<AsyncFuture>::get_cancelled_error_type());
+      PyObject *func = PyObject_GetAttrString(_generator, "throw");
+      result = PyObject_CallFunctionObjArgs(func, exc, nullptr);
+      Py_DECREF(func);
+      Py_DECREF(exc);
+    }
 
     if (result == nullptr) {
       // An error happened.  If StopIteration, that indicates the task has
@@ -509,6 +565,12 @@ do_python_task() {
       if (_PyGen_FetchStopIterationValue(&result) == 0) {
         PyErr_Clear();
 
+        if (_must_cancel) {
+          // Task was cancelled right before finishing.  Make sure it is not
+          // getting rerun or marked as successfully completed.
+          _state = S_servicing_removed;
+        }
+
         // If we passed a coroutine into the task, eg. something like:
         //   taskMgr.add(my_async_function())
         // then we cannot rerun the task, so the return value is always
@@ -524,6 +586,18 @@ do_python_task() {
           _exc_value = result;
           return DS_done;
         }
+
+      } else if (PyErr_ExceptionMatches(Extension<AsyncFuture>::get_cancelled_error_type())) {
+        // Someone cancelled the coroutine, and it did not bother to handle it,
+        // so we should consider it cancelled.
+        if (task_cat.is_debug()) {
+          task_cat.debug()
+            << *this << " was cancelled and did not catch CancelledError.\n";
+        }
+        _state = S_servicing_removed;
+        PyErr_Clear();
+        return DS_done;
+
       } else if (_function == nullptr) {
         // We got an exception.  If this is a scheduled coroutine, we will
         // keep it and instead throw it into whatever 'awaits' this task.

+ 3 - 0
panda/src/event/pythonTask.h

@@ -90,6 +90,8 @@ PUBLISHED:
   PyObject *__dict__;
 
 protected:
+  virtual bool cancel();
+
   virtual bool is_runnable();
   virtual DoneStatus do_task();
   DoneStatus do_python_task();
@@ -119,6 +121,7 @@ private:
   bool _ignore_return;
   bool _registered_to_owner;
   mutable bool _retrieved_exception;
+  bool _must_cancel = false;
 
   friend class Extension<AsyncFuture>;
 

+ 6 - 1
tests/event/test_futures.py

@@ -1,7 +1,12 @@
 from panda3d import core
 import pytest
 import time
-from concurrent.futures._base import TimeoutError, CancelledError
+import sys
+
+if sys.version_info >= (3, 8):
+    from asyncio.exceptions import TimeoutError, CancelledError
+else:
+    from concurrent.futures._base import TimeoutError, CancelledError
 
 
 def test_future_cancelled():