2
0
Эх сурвалжийг харах

event: cancel future being awaited when cancelling coroutine task

Fixes #1136
rdb 4 жил өмнө
parent
commit
5ef1b44455

+ 1 - 1
panda/src/event/asyncFuture.h

@@ -67,7 +67,7 @@ PUBLISHED:
 
 
   INLINE bool done() const;
   INLINE bool done() const;
   INLINE bool cancelled() const;
   INLINE bool cancelled() const;
-  EXTENSION(PyObject *result(PyObject *timeout = Py_None) const);
+  EXTENSION(PyObject *result(PyObject *self, PyObject *timeout = Py_None) const);
 
 
   virtual bool cancel();
   virtual bool cancel();
 
 

+ 28 - 9
panda/src/event/asyncFuture_ext.cxx

@@ -222,16 +222,35 @@ set_result(PyObject *result) {
  * raises TimeoutError.
  * raises TimeoutError.
  */
  */
 PyObject *Extension<AsyncFuture>::
 PyObject *Extension<AsyncFuture>::
-result(PyObject *timeout) const {
+result(PyObject *self, PyObject *timeout) const {
+  double timeout_val;
+  if (timeout != Py_None) {
+    timeout_val = PyFloat_AsDouble(timeout);
+    if (timeout_val == -1.0 && _PyErr_OCCURRED()) {
+      return nullptr;
+    }
+  }
+
   if (!_this->done()) {
   if (!_this->done()) {
     // Not yet done?  Wait until it is done, or until a timeout occurs.  But
     // Not yet done?  Wait until it is done, or until a timeout occurs.  But
     // first check to make sure we're not trying to deadlock the thread.
     // first check to make sure we're not trying to deadlock the thread.
     Thread *current_thread = Thread::get_current_thread();
     Thread *current_thread = Thread::get_current_thread();
-    if (_this == (const AsyncFuture *)current_thread->get_current_task()) {
+    AsyncTask *current_task = (AsyncTask *)current_thread->get_current_task();
+    if (_this == current_task) {
       PyErr_SetString(PyExc_RuntimeError, "cannot call task.result() from within the task");
       PyErr_SetString(PyExc_RuntimeError, "cannot call task.result() from within the task");
       return nullptr;
       return nullptr;
     }
     }
 
 
+    PythonTask *python_task = nullptr;
+    if (current_task != nullptr &&
+        current_task->is_of_type(PythonTask::get_class_type())) {
+      // If we are calling result() inside a coroutine, mark it as awaiting this
+      // future.  That makes it possible to cancel() us from another thread.
+      python_task = (PythonTask *)current_task;
+      nassertr(python_task->_fut_waiter == nullptr, nullptr);
+      python_task->_fut_waiter = self;
+    }
+
     // Release the GIL for the duration.
     // Release the GIL for the duration.
 #if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
 #if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
     PyThreadState *_save;
     PyThreadState *_save;
@@ -239,18 +258,18 @@ result(PyObject *timeout) const {
 #endif
 #endif
     if (timeout == Py_None) {
     if (timeout == Py_None) {
       _this->wait();
       _this->wait();
-    } else {
-      PyObject *num = PyNumber_Float(timeout);
-      if (num != nullptr) {
-        _this->wait(PyFloat_AS_DOUBLE(num));
-      } else {
-        return Dtool_Raise_ArgTypeError(timeout, 0, "result", "float");
-      }
+    }
+    else {
+      _this->wait(timeout_val);
     }
     }
 #if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
 #if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
     Py_BLOCK_THREADS
     Py_BLOCK_THREADS
 #endif
 #endif
 
 
+    if (python_task != nullptr) {
+      python_task->_fut_waiter = nullptr;
+    }
+
     if (!_this->done()) {
     if (!_this->done()) {
       // It timed out.  Raise an exception.
       // It timed out.  Raise an exception.
       static PyObject *exc_type = nullptr;
       static PyObject *exc_type = nullptr;

+ 1 - 1
panda/src/event/asyncFuture_ext.h

@@ -30,7 +30,7 @@ public:
   static PyObject *__iter__(PyObject *self) { return __await__(self); }
   static PyObject *__iter__(PyObject *self) { return __await__(self); }
 
 
   void set_result(PyObject *result);
   void set_result(PyObject *result);
-  PyObject *result(PyObject *timeout = Py_None) const;
+  PyObject *result(PyObject *self, PyObject *timeout = Py_None) const;
 
 
   PyObject *add_done_callback(PyObject *self, PyObject *fn);
   PyObject *add_done_callback(PyObject *self, PyObject *fn);
 
 

+ 72 - 23
panda/src/event/pythonTask.cxx

@@ -45,7 +45,7 @@ PythonTask(PyObject *func_or_coro, const std::string &name) :
   _exc_value(nullptr),
   _exc_value(nullptr),
   _exc_traceback(nullptr),
   _exc_traceback(nullptr),
   _generator(nullptr),
   _generator(nullptr),
-  _future_done(nullptr),
+  _fut_waiter(nullptr),
   _ignore_return(false),
   _ignore_return(false),
   _retrieved_exception(false) {
   _retrieved_exception(false) {
 
 
@@ -404,20 +404,58 @@ cancel() {
         << "Cancelling " << *this << "\n";
         << "Cancelling " << *this << "\n";
     }
     }
 
 
+    bool must_cancel = true;
+    if (_fut_waiter != nullptr) {
+      // Cancel the future that this task is waiting on.  Note that we do this
+      // before grabbing the lock, since this operation may also grab it.  This
+      // means that _fut_waiter is only protected by the GIL.
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+      // Use PyGILState to protect this asynchronous call.
+      PyGILState_STATE gstate;
+      gstate = PyGILState_Ensure();
+#endif
+
+      // Shortcut for unextended AsyncFuture.
+      if (Py_TYPE(_fut_waiter) == (PyTypeObject *)&Dtool_AsyncFuture) {
+        AsyncFuture *fut = (AsyncFuture *)DtoolInstance_VOID_PTR(_fut_waiter);
+        if (!fut->done()) {
+          fut->cancel();
+        }
+        if (fut->done()) {
+          // We don't need this anymore.
+          Py_DECREF(_fut_waiter);
+          _fut_waiter = nullptr;
+        }
+      }
+      else {
+        PyObject *result = PyObject_CallMethod(_fut_waiter, "cancel", nullptr);
+        Py_XDECREF(result);
+      }
+
+#if defined(HAVE_THREADS) && !defined(SIMPLE_THREADS)
+      PyGILState_Release(gstate);
+#endif
+      // Keep _fut_waiter in any case, because we may need to cancel it again
+      // later if it ignores the cancellation.
+    }
+
     MutexHolder holder(manager->_lock);
     MutexHolder holder(manager->_lock);
     if (_state == S_awaiting) {
     if (_state == S_awaiting) {
       // Reactivate it so that it can receive a CancelledException.
       // Reactivate it so that it can receive a CancelledException.
-      _must_cancel = true;
+      if (must_cancel) {
+        _must_cancel = true;
+      }
       _state = AsyncTask::S_active;
       _state = AsyncTask::S_active;
       _chain->_active.push_back(this);
       _chain->_active.push_back(this);
       --_chain->_num_awaiting_tasks;
       --_chain->_num_awaiting_tasks;
       return true;
       return true;
     }
     }
-    else if (_future_done != nullptr) {
-      // We are polling, waiting for a non-Panda future to be done.
-      Py_DECREF(_future_done);
-      _future_done = nullptr;
-      _must_cancel = true;
+    else if (must_cancel || _fut_waiter != nullptr) {
+      // We may be polling an external future, so we still need to throw a
+      // CancelledException and allow it to be caught.
+      if (must_cancel) {
+        _must_cancel = true;
+      }
       return true;
       return true;
     }
     }
     else if (_chain->do_remove(this, true)) {
     else if (_chain->do_remove(this, true)) {
@@ -477,17 +515,24 @@ AsyncTask::DoneStatus PythonTask::
 do_python_task() {
 do_python_task() {
   PyObject *result = nullptr;
   PyObject *result = nullptr;
 
 
-  // Are we waiting for a future to finish?
-  if (_future_done != nullptr) {
-    PyObject *is_done = PyObject_CallNoArgs(_future_done);
-    if (!PyObject_IsTrue(is_done)) {
-      // Nope, ask again next frame.
+  // Are we waiting for a future to finish?  Short-circuit all the logic below
+  // by simply calling done().
+  {
+    PyObject *fut_waiter = _fut_waiter;
+    if (fut_waiter != nullptr) {
+      PyObject *is_done = PyObject_CallMethod(fut_waiter, "done", nullptr);
+      if (is_done == nullptr) {
+        return DS_interrupt;
+      }
+      if (!PyObject_IsTrue(is_done)) {
+        // Nope, ask again next frame.
+        Py_DECREF(is_done);
+        return DS_cont;
+      }
       Py_DECREF(is_done);
       Py_DECREF(is_done);
-      return DS_cont;
+      Py_DECREF(fut_waiter);
+      _fut_waiter = nullptr;
     }
     }
-    Py_DECREF(is_done);
-    Py_DECREF(_future_done);
-    _future_done = nullptr;
   }
   }
 
 
   if (_generator == nullptr) {
   if (_generator == nullptr) {
@@ -664,7 +709,9 @@ do_python_task() {
           task_cat.error()
           task_cat.error()
             << *this << " cannot await itself\n";
             << *this << " cannot await itself\n";
         }
         }
-        Py_DECREF(result);
+        // Store the Python object in case we need to cancel it (it may be a
+        // subclass of AsyncFuture that overrides cancel() from Python)
+        _fut_waiter = result;
         return DS_await;
         return DS_await;
       }
       }
     } else {
     } else {
@@ -674,8 +721,9 @@ do_python_task() {
       if (check != nullptr && check != Py_None) {
       if (check != nullptr && check != Py_None) {
         Py_DECREF(check);
         Py_DECREF(check);
         // Next frame, check whether this future is done.
         // Next frame, check whether this future is done.
-        _future_done = PyObject_GetAttrString(result, "done");
-        if (_future_done == nullptr || !PyCallable_Check(_future_done)) {
+        PyObject *fut_done = PyObject_GetAttrString(result, "done");
+        if (fut_done == nullptr || !PyCallable_Check(fut_done)) {
+          Py_XDECREF(fut_done);
           task_cat.error()
           task_cat.error()
             << "future.done is not callable\n";
             << "future.done is not callable\n";
           return DS_interrupt;
           return DS_interrupt;
@@ -686,7 +734,7 @@ do_python_task() {
             << *this << " is now polling " << PyUnicode_AsUTF8(str) << ".done()\n";
             << *this << " is now polling " << PyUnicode_AsUTF8(str) << ".done()\n";
           Py_DECREF(str);
           Py_DECREF(str);
         }
         }
-        Py_DECREF(result);
+        _fut_waiter = result;
         return DS_cont;
         return DS_cont;
       }
       }
       PyErr_Clear();
       PyErr_Clear();
@@ -802,9 +850,10 @@ upon_death(AsyncTaskManager *manager, bool clean_exit) {
   AsyncTask::upon_death(manager, clean_exit);
   AsyncTask::upon_death(manager, clean_exit);
 
 
   // If we were polling something when we were removed, get rid of it.
   // If we were polling something when we were removed, get rid of it.
-  if (_future_done != nullptr) {
-    Py_DECREF(_future_done);
-    _future_done = nullptr;
+  //TODO: should we call cancel() on it?
+  if (_fut_waiter != nullptr) {
+    Py_DECREF(_fut_waiter);
+    _fut_waiter = nullptr;
   }
   }
 
 
   if (_upon_death != Py_None) {
   if (_upon_death != Py_None) {

+ 1 - 1
panda/src/event/pythonTask.h

@@ -115,7 +115,7 @@ private:
   PyObject *_exc_traceback;
   PyObject *_exc_traceback;
 
 
   PyObject *_generator;
   PyObject *_generator;
-  PyObject *_future_done;
+  PyObject *_fut_waiter;
 
 
   bool _append_task;
   bool _append_task;
   bool _ignore_return;
   bool _ignore_return;

+ 221 - 0
tests/event/test_futures.py

@@ -9,6 +9,33 @@ else:
     from concurrent.futures._base import TimeoutError, CancelledError
     from concurrent.futures._base import TimeoutError, CancelledError
 
 
 
 
+class MockFuture:
+    _asyncio_future_blocking = False
+    _state = 'PENDING'
+    _cancel_return = False
+    _result = None
+
+    def __await__(self):
+        while self._state == 'PENDING':
+            yield self
+        return self.result()
+
+    def done(self):
+        return self._state != 'PENDING'
+
+    def cancelled(self):
+        return self._state == 'CANCELLED'
+
+    def cancel(self):
+        return self._cancel_return
+
+    def result(self):
+        if self._state == 'CANCELLED':
+            raise CancelledError
+
+        return self._result
+
+
 def test_future_cancelled():
 def test_future_cancelled():
     fut = core.AsyncFuture()
     fut = core.AsyncFuture()
 
 
@@ -123,6 +150,66 @@ def test_task_cancel_during_run():
         task.result()
         task.result()
 
 
 
 
+def test_task_cancel_waiting():
+    # Calling result() in a threaded task chain should cancel the future being
+    # waited on if the surrounding task is cancelled.
+    task_mgr = core.AsyncTaskManager.get_global_ptr()
+    task_chain = task_mgr.make_task_chain("test_task_cancel_waiting")
+    task_chain.set_num_threads(1)
+
+    fut = core.AsyncFuture()
+
+    async def task_main(task):
+        # This will block the thread this task is in until the future is done,
+        # or until the task is cancelled (which implicitly cancels the future).
+        fut.result()
+        return task.done
+
+    task = core.PythonTask(task_main, 'task_main')
+    task.set_task_chain(task_chain.name)
+    task_mgr.add(task)
+
+    task_chain.start_threads()
+    try:
+        assert not task.done()
+        fut.cancel()
+        task.wait()
+
+        assert task.cancelled()
+        assert fut.cancelled()
+
+    finally:
+        task_chain.stop_threads()
+
+
+def test_task_cancel_awaiting():
+    task_mgr = core.AsyncTaskManager.get_global_ptr()
+    task_chain = task_mgr.make_task_chain("test_task_cancel_awaiting")
+
+    fut = core.AsyncFuture()
+
+    async def task_main(task):
+        await fut
+        return task.done
+
+    task = core.PythonTask(task_main, 'task_main')
+    task.set_task_chain(task_chain.name)
+    task_mgr.add(task)
+
+    task_chain.poll()
+    assert not task.done()
+
+    task_chain.poll()
+    assert not task.done()
+
+    task.cancel()
+    task_chain.poll()
+    assert task.done()
+    assert task.cancelled()
+    assert fut.done()
+    assert fut.cancelled()
+
+
 def test_task_result():
 def test_task_result():
     task_mgr = core.AsyncTaskManager.get_global_ptr()
     task_mgr = core.AsyncTaskManager.get_global_ptr()
     task_chain = task_mgr.make_task_chain("test_task_result")
     task_chain = task_mgr.make_task_chain("test_task_result")
@@ -144,6 +231,140 @@ def test_task_result():
     assert task.result() == 42
     assert task.result() == 42
 
 
 
 
+def test_coro_await_coro():
+    # Await another coro in a coro.
+    fut = core.AsyncFuture()
+    async def coro2():
+        await fut
+
+    async def coro_main():
+        await coro2()
+
+    task = core.PythonTask(coro_main())
+
+    task_mgr = core.AsyncTaskManager.get_global_ptr()
+    task_mgr.add(task)
+    for i in range(5):
+        task_mgr.poll()
+
+    assert not task.done()
+    fut.set_result(None)
+    task_mgr.poll()
+    assert task.done()
+    assert not task.cancelled()
+
+
+def test_coro_await_cancel_resistant_coro():
+    # Await another coro in a coro, but cancel the outer.
+    fut = core.AsyncFuture()
+    cancelled_caught = [0]
+    keep_going = [False]
+
+    async def cancel_resistant_coro():
+        while not fut.done():
+            try:
+                await core.AsyncFuture.shield(fut)
+            except CancelledError as ex:
+                cancelled_caught[0] += 1
+
+    async def coro_main():
+        await cancel_resistant_coro()
+
+    task = core.PythonTask(coro_main(), 'coro_main')
+
+    task_mgr = core.AsyncTaskManager.get_global_ptr()
+    task_mgr.add(task)
+    assert not task.done()
+
+    task_mgr.poll()
+    assert not task.done()
+
+    # No cancelling it once it started...
+    for i in range(3):
+        assert task.cancel()
+        assert not task.done()
+
+        for j in range(3):
+            task_mgr.poll()
+            assert not task.done()
+
+    assert cancelled_caught[0] == 3
+
+    fut.set_result(None)
+    task_mgr.poll()
+    assert task.done()
+    assert not task.cancelled()
+
+
+def test_coro_await_external():
+    # Await an external future in a coro.
+    fut = MockFuture()
+    fut._result = 12345
+    res = []
+
+    async def coro_main():
+        res.append(await fut)
+
+    task = core.PythonTask(coro_main(), 'coro_main')
+
+    task_mgr = core.AsyncTaskManager.get_global_ptr()
+    task_mgr.add(task)
+    for i in range(5):
+        task_mgr.poll()
+
+    assert not task.done()
+    fut._state = 'FINISHED'
+    task_mgr.poll()
+    assert task.done()
+    assert not task.cancelled()
+    assert res == [12345]
+
+
+def test_coro_await_external_cancel_inner():
+    # Cancel external future being awaited by a coro.
+    fut = MockFuture()
+
+    async def coro_main():
+        await fut
+
+    task = core.PythonTask(coro_main(), 'coro_main')
+
+    task_mgr = core.AsyncTaskManager.get_global_ptr()
+    task_mgr.add(task)
+    for i in range(5):
+        task_mgr.poll()
+
+    assert not task.done()
+    fut._state = 'CANCELLED'
+    assert not task.done()
+    task_mgr.poll()
+    assert task.done()
+    assert task.cancelled()
+
+
+def test_coro_await_external_cancel_outer():
+    # Cancel task that is awaiting external future.
+    fut = MockFuture()
+    result = []
+
+    async def coro_main():
+        result.append(await fut)
+
+    task = core.PythonTask(coro_main(), 'coro_main')
+
+    task_mgr = core.AsyncTaskManager.get_global_ptr()
+    task_mgr.add(task)
+    for i in range(5):
+        task_mgr.poll()
+
+    assert not task.done()
+    fut._state = 'CANCELLED'
+    assert not task.done()
+    task_mgr.poll()
+    assert task.done()
+    assert task.cancelled()
+
+
 def test_coro_exception():
 def test_coro_exception():
     task_mgr = core.AsyncTaskManager.get_global_ptr()
     task_mgr = core.AsyncTaskManager.get_global_ptr()
     task_chain = task_mgr.make_task_chain("test_coro_exception")
     task_chain = task_mgr.make_task_chain("test_coro_exception")