Browse Source

event: Fix `await gather()` returning first item instead of tuple

Fixes #1738
rdb 4 months ago
parent
commit
41e4cf5d11
2 changed files with 46 additions and 7 deletions
  1. 21 2
      panda/src/event/asyncFuture_ext.cxx
  2. 25 5
      tests/event/test_futures.py

+ 21 - 2
panda/src/event/asyncFuture_ext.cxx

@@ -167,8 +167,27 @@ static PyObject *gen_next(PyObject *self) {
   } else {
   } else {
     PyObject *result = get_done_result(future);
     PyObject *result = get_done_result(future);
     if (result != nullptr) {
     if (result != nullptr) {
-      Py_INCREF(PyExc_StopIteration);
-      PyErr_Restore(PyExc_StopIteration, result, nullptr);
+      // See python/cpython#101578 - PyErr_SetObject has a special case where
+      // it interprets a tuple specially, so we bypass that by creating the
+      // exception directly.
+#if PY_VERSION_HEX >= 0x030C0000 // 3.12
+      PyObject *exc = PyObject_CallOneArg(PyExc_StopIteration, result);
+      if (LIKELY(exc != nullptr)) {
+        // This function steals a reference to exc.
+        PyErr_SetRaisedException(exc);
+      }
+#else
+      if (PyTuple_Check(result)) {
+        PyObject *exc = PyObject_CallOneArg(PyExc_StopIteration, result);
+        if (LIKELY(exc != nullptr)) {
+          PyErr_SetObject(PyExc_StopIteration, exc);
+          Py_DECREF(exc);
+        }
+      } else {
+        Py_INCREF(PyExc_StopIteration);
+        PyErr_Restore(PyExc_StopIteration, result, nullptr);
+      }
+#endif
     }
     }
     return nullptr;
     return nullptr;
   }
   }

+ 25 - 5
tests/event/test_futures.py

@@ -10,6 +10,21 @@ else:
     CancelledError = Exception
     CancelledError = Exception
 
 
 
 
+def check_result(fut, expected):
+    """Asserts the result of the future is the expected value."""
+
+    if fut.result() != expected:
+        return False
+
+    # Make sure that await also returns the values properly
+    with pytest.raises(StopIteration) as e:
+        next(fut.__await__())
+    if e.value.value != expected:
+        return False
+
+    return True
+
+
 def test_future_cancelled():
 def test_future_cancelled():
     fut = core.AsyncFuture()
     fut = core.AsyncFuture()
 
 
@@ -205,15 +220,20 @@ def test_future_result():
     ep = core.EventParameter(0.5)
     ep = core.EventParameter(0.5)
     fut = core.AsyncFuture()
     fut = core.AsyncFuture()
     fut.set_result(ep)
     fut.set_result(ep)
-    assert fut.result() == 0.5
-    assert fut.result() == 0.5
+    assert check_result(fut, 0.5)
+    assert check_result(fut, 0.5)
 
 
     # Store TypedObject
     # Store TypedObject
     dg = core.Datagram(b"test")
     dg = core.Datagram(b"test")
     fut = core.AsyncFuture()
     fut = core.AsyncFuture()
     fut.set_result(dg)
     fut.set_result(dg)
-    assert fut.result() == dg
-    assert fut.result() == dg
+    assert check_result(fut, dg)
+    assert check_result(fut, dg)
+
+    # Store tuple
+    fut = core.AsyncFuture()
+    fut.set_result((1, 2))
+    assert check_result(fut, (1, 2))
 
 
     # Store arbitrary Python object
     # Store arbitrary Python object
     obj = object()
     obj = object()
@@ -250,7 +270,7 @@ def test_future_gather():
     assert gather.done()
     assert gather.done()
 
 
     assert not gather.cancelled()
     assert not gather.cancelled()
-    assert tuple(gather.result()) == (1, 2)
+    assert check_result(gather, (1, 2))
 
 
 
 
 def test_future_gather_cancel_inner():
 def test_future_gather_cancel_inner():