Browse Source

fsm: Support asynchronous transitions via coroutine enter/exit funcs

Fixes #1037
rdb 4 years ago
parent
commit
9cb3c7726f
1 changed files with 85 additions and 23 deletions
  1. 85 23
      direct/src/fsm/FSM.py

+ 85 - 23
direct/src/fsm/FSM.py

@@ -13,6 +13,8 @@ from direct.showbase.MessengerGlobal import messenger
 from direct.showbase import PythonUtil
 from direct.directnotify import DirectNotifyGlobal
 from direct.stdpy.threading import RLock
+from panda3d.core import AsyncTaskManager, AsyncFuture, PythonTask
+import types
 
 
 class FSMException(Exception):
@@ -27,6 +29,19 @@ class RequestDenied(FSMException):
     pass
 
 
+class Transition(tuple):
+    """Used for the return value of fsm.request().  Behaves like a tuple, for
+    historical reasons."""
+
+    _future = None
+
+    def __await__(self):
+        if self._future:
+            yield self._future
+
+        return tuple(self)
+
+
 class FSM(DirectObject):
     """
     A Finite State Machine.  This is intended to be the base class
@@ -154,6 +169,9 @@ class FSM(DirectObject):
     # must be approved by some filter function.
     defaultTransitions = None
 
+    __doneFuture = AsyncFuture()
+    __doneFuture.set_result(None)
+
     # An enum class for special states like the DEFAULT or ANY state,
     # that should be treatened by the FSM in a special way
     class EnumStates():
@@ -247,7 +265,13 @@ class FSM(DirectObject):
     def forceTransition(self, request, *args):
         """Changes unconditionally to the indicated state.  This
         bypasses the filterState() function, and just calls
-        exitState() followed by enterState()."""
+        exitState() followed by enterState().
+
+        If the FSM is currently undergoing a transition, this will
+        queue up the new transition.
+
+        Returns a future, which can be used to await the transition.
+        """
 
         self.fsmLock.acquire()
         try:
@@ -257,11 +281,13 @@ class FSM(DirectObject):
 
             if not self.state:
                 # Queue up the request.
-                self.__requestQueue.append(PythonUtil.Functor(
-                    self.forceTransition, request, *args))
-                return
+                fut = AsyncFuture()
+                self.__requestQueue.append((PythonUtil.Functor(
+                    self.forceTransition, request, *args), fut))
+                return fut
 
-            self.__setState(request, *args)
+            result = self.__setState(request, *args)
+            return result._future or self.__doneFuture
         finally:
             self.fsmLock.release()
 
@@ -275,6 +301,10 @@ class FSM(DirectObject):
         request is queued up and will be executed when the current
         transition finishes.  Multiple requests will queue up in
         sequence.
+
+        The return value of this function can be used in an `await`
+        expression to suspend the current coroutine until the
+        transition is done.
         """
 
         self.fsmLock.acquire()
@@ -284,12 +314,15 @@ class FSM(DirectObject):
                 self._name, request, str(args)[1:]))
             if not self.state:
                 # Queue up the request.
-                self.__requestQueue.append(PythonUtil.Functor(
-                    self.demand, request, *args))
-                return
+                fut = AsyncFuture()
+                self.__requestQueue.append((PythonUtil.Functor(
+                    self.demand, request, *args), fut))
+                return fut
 
-            if not self.request(request, *args):
+            result = self.request(request, *args)
+            if not result:
                 raise RequestDenied("%s (from state: %s)" % (request, self.state))
+            return result._future or self.__doneFuture
         finally:
             self.fsmLock.release()
 
@@ -314,7 +347,12 @@ class FSM(DirectObject):
         executing an enterState or exitState function), an
         `AlreadyInTransition` exception is raised (but see `demand()`,
         which will queue these requests up and apply when the
-        transition is complete)."""
+        transition is complete).
+
+        If the previous state's exitFunc or the new state's enterFunc
+        is a coroutine, the state change may not have been applied by
+        the time request() returns, but you can use `await` on the
+        return value to await the transition."""
 
         self.fsmLock.acquire()
         try:
@@ -331,7 +369,7 @@ class FSM(DirectObject):
                     result = (result,) + args
 
                 # Otherwise, assume it's a (name, *args) tuple
-                self.__setState(*result)
+                return self.__setState(*result)
 
             return result
         finally:
@@ -441,11 +479,11 @@ class FSM(DirectObject):
         try:
             if self.stateArray:
                 if not self.state in self.stateArray:
-                    self.request(self.stateArray[0])
+                    return self.request(self.stateArray[0])
                 else:
                     cur_index = self.stateArray.index(self.state)
                     new_index = (cur_index + 1) % len(self.stateArray)
-                    self.request(self.stateArray[new_index], args)
+                    return self.request(self.stateArray[new_index], args)
             else:
                 assert self.notifier.debug(
                                     "stateArray empty. Can't switch to next.")
@@ -459,11 +497,11 @@ class FSM(DirectObject):
         try:
             if self.stateArray:
                 if not self.state in self.stateArray:
-                    self.request(self.stateArray[0])
+                    return self.request(self.stateArray[0])
                 else:
                     cur_index = self.stateArray.index(self.state)
                     new_index = (cur_index - 1) % len(self.stateArray)
-                    self.request(self.stateArray[new_index], args)
+                    return self.request(self.stateArray[new_index], args)
             else:
                 assert self.notifier.debug(
                                     "stateArray empty. Can't switch to next.")
@@ -471,8 +509,26 @@ class FSM(DirectObject):
             self.fsmLock.release()
 
     def __setState(self, newState, *args):
-        # Internal function to change unconditionally to the indicated
-        # state.
+        # Internal function to change unconditionally to the indicated state.
+
+        transition = Transition((newState,) + args)
+
+        # See if we can transition immediately by polling the coroutine.
+        coro = self.__transition(newState, *args)
+        try:
+            coro.send(None)
+        except StopIteration:
+            # We managed to apply this straight away.
+            return transition
+
+        # Continue the state transition in a task.
+        task = PythonTask(coro)
+        mgr = AsyncTaskManager.get_global_ptr()
+        mgr.add(task)
+        transition._future = task
+        return transition
+
+    async def __transition(self, newState, *args):
         assert self.state
         assert self.notify.debug("%s to state %s." % (self._name, newState))
 
@@ -482,8 +538,13 @@ class FSM(DirectObject):
 
         try:
             if not self.__callFromToFunc(self.oldState, self.newState, *args):
-                self.__callExitFunc(self.oldState)
-                self.__callEnterFunc(self.newState, *args)
+                result = self.__callExitFunc(self.oldState)
+                if isinstance(result, types.CoroutineType):
+                    await result
+
+                result = self.__callEnterFunc(self.newState, *args)
+                if isinstance(result, types.CoroutineType):
+                    await result
         except:
             # If we got an exception during the enter or exit methods,
             # go directly to state "InternalError" and raise up the
@@ -503,9 +564,10 @@ class FSM(DirectObject):
         del self.newState
 
         if self.__requestQueue:
-            request = self.__requestQueue.pop(0)
+            request, fut = self.__requestQueue.pop(0)
             assert self.notify.debug("%s continued queued request." % (self._name))
-            request()
+            await request()
+            fut.set_result(None)
 
     def __callEnterFunc(self, name, *args):
         # Calls the appropriate enter function when transitioning into
@@ -517,7 +579,7 @@ class FSM(DirectObject):
             # If there's no matching enterFoo() function, call
             # defaultEnter() instead.
             func = self.defaultEnter
-        func(*args)
+        return func(*args)
 
     def __callFromToFunc(self, oldState, newState, *args):
         # Calls the appropriate fromTo function when transitioning into
@@ -540,7 +602,7 @@ class FSM(DirectObject):
             # If there's no matching exitFoo() function, call
             # defaultExit() instead.
             func = self.defaultExit
-        func()
+        return func()
 
     def __repr__(self):
         return self.__str__()