Browse Source

task: Add AsyncFuture::shield() ability, part of #1136

This is modelled after `asyncio.shield()` and can be used to protect an inner future from cancellation when the outer future is cancelled.
rdb 4 years ago
parent
commit
cdf5b16ddd

+ 1 - 0
direct/src/task/Task.py

@@ -91,6 +91,7 @@ pause = AsyncTaskPause
 Task.DtoolClassDict['pause'] = staticmethod(pause)
 Task.DtoolClassDict['pause'] = staticmethod(pause)
 
 
 gather = Task.gather
 gather = Task.gather
+shield = Task.shield
 
 
 def sequence(*taskList):
 def sequence(*taskList):
     seq = AsyncTaskSequence('sequence')
     seq = AsyncTaskSequence('sequence')

+ 18 - 0
panda/src/event/asyncFuture.I

@@ -133,6 +133,24 @@ gather(Futures futures) {
   }
   }
 }
 }
 
 
+/**
+ * Creates a new future that shields the given future from cancellation.
+ * Calling `cancel()` on the returned future will not affect the given future.
+ */
+INLINE PT(AsyncFuture) AsyncFuture::
+shield(PT(AsyncFuture) future) {
+  if (future->try_lock_pending()) {
+    PT(AsyncFuture) outer = new AsyncFuture;
+    outer->_manager = future->_manager;
+    future->_waiting.push_back((AsyncFuture *)outer);
+    future->unlock();
+    return outer;
+  }
+  else {
+    return future;
+  }
+}
+
 /**
 /**
  * Tries to atomically lock the future, assuming it is pending.  Returns false
  * Tries to atomically lock the future, assuming it is pending.  Returns false
  * if it is not in the pending state, implying it's either done or about to be
  * if it is not in the pending state, implying it's either done or about to be

+ 21 - 4
panda/src/event/asyncFuture.cxx

@@ -148,13 +148,13 @@ notify_done(bool clean_exit) {
   // This will only be called by the thread that managed to set the
   // This will only be called by the thread that managed to set the
   // _future_state away from the "pending" state, so this is thread safe.
   // _future_state away from the "pending" state, so this is thread safe.
 
 
-  Futures::iterator it;
-  for (it = _waiting.begin(); it != _waiting.end(); ++it) {
-    AsyncFuture *fut = *it;
+  // Go through the futures that are waiting for this to finish.
+  for (AsyncFuture *fut : _waiting) {
     if (fut->is_task()) {
     if (fut->is_task()) {
       // It's a task.  Make it active again.
       // It's a task.  Make it active again.
       wake_task((AsyncTask *)fut);
       wake_task((AsyncTask *)fut);
-    } else {
+    }
+    else if (fut->get_type() == AsyncGatheringFuture::get_class_type()) {
       // It's a gathering future.  Decrease the pending count on it, and if
       // It's a gathering future.  Decrease the pending count on it, and if
       // we're the last one, call notify_done() on it.
       // we're the last one, call notify_done() on it.
       AsyncGatheringFuture *gather = (AsyncGatheringFuture *)fut;
       AsyncGatheringFuture *gather = (AsyncGatheringFuture *)fut;
@@ -164,6 +164,23 @@ notify_done(bool clean_exit) {
         }
         }
       }
       }
     }
     }
+    else {
+      // It's a shielding future.  The shielding only protects the inner future
+      // when the outer is cancelled, not the other way around, so we have to
+      // propagate any cancellation here as well.
+      if (clean_exit && _result != nullptr) {
+        // Propagate the result, if any.
+        if (fut->try_lock_pending()) {
+          fut->_result = _result;
+          fut->_result_ref = _result_ref;
+          fut->unlock(FS_finished);
+          fut->notify_done(true);
+        }
+      }
+      else if (fut->set_future_state(clean_exit ? FS_finished : FS_cancelled)) {
+        fut->notify_done(clean_exit);
+      }
+    }
   }
   }
   _waiting.clear();
   _waiting.clear();
 
 

+ 1 - 0
panda/src/event/asyncFuture.h

@@ -78,6 +78,7 @@ PUBLISHED:
   EXTENSION(PyObject *add_done_callback(PyObject *self, PyObject *fn));
   EXTENSION(PyObject *add_done_callback(PyObject *self, PyObject *fn));
 
 
   EXTENSION(static PyObject *gather(PyObject *args));
   EXTENSION(static PyObject *gather(PyObject *args));
+  INLINE static PT(AsyncFuture) shield(PT(AsyncFuture) future);
 
 
   virtual void output(std::ostream &out) const;
   virtual void output(std::ostream &out) const;
 
 

+ 60 - 0
tests/event/test_futures.py

@@ -289,6 +289,66 @@ def test_future_gather_cancel_outer():
         assert gather.result()
         assert gather.result()
 
 
 
 
+def test_future_shield():
+    # An already done future is returned as-is (no cancellation can occur)
+    inner = core.AsyncFuture()
+    inner.set_result(None)
+    outer = core.AsyncFuture.shield(inner)
+    assert inner == outer
+
+    # Normally finishing future
+    inner = core.AsyncFuture()
+    outer = core.AsyncFuture.shield(inner)
+    assert not outer.done()
+    inner.set_result(None)
+    assert outer.done()
+    assert not outer.cancelled()
+    assert inner.result() is None
+
+    # Normally finishing future with result
+    inner = core.AsyncFuture()
+    outer = core.AsyncFuture.shield(inner)
+    assert not outer.done()
+    inner.set_result(123)
+    assert outer.done()
+    assert not outer.cancelled()
+    assert inner.result() == 123
+
+    # Cancelled inner future does propagate cancellation outward
+    inner = core.AsyncFuture()
+    outer = core.AsyncFuture.shield(inner)
+    assert not outer.done()
+    inner.cancel()
+    assert outer.done()
+    assert outer.cancelled()
+
+    # Finished outer future does nothing to inner
+    inner = core.AsyncFuture()
+    outer = core.AsyncFuture.shield(inner)
+    outer.set_result(None)
+    assert not inner.done()
+    inner.cancel()
+    assert not outer.cancelled()
+
+    # Cancelled outer future does nothing to inner
+    inner = core.AsyncFuture()
+    outer = core.AsyncFuture.shield(inner)
+    outer.cancel()
+    assert not inner.done()
+    inner.cancel()
+
+    # Can be shielded multiple times
+    inner = core.AsyncFuture()
+    outer1 = core.AsyncFuture.shield(inner)
+    outer2 = core.AsyncFuture.shield(inner)
+    outer1.cancel()
+    assert not inner.done()
+    assert not outer2.done()
+    inner.cancel()
+    assert outer1.done()
+    assert outer2.done()
+
+
 def test_future_done_callback():
 def test_future_done_callback():
     fut = core.AsyncFuture()
     fut = core.AsyncFuture()