Browse Source

fsm: Support asynchronous transitions via coroutine enter/exit funcs

Fixes #1037
rdb 4 years ago
parent
commit
9cb3c7726f
1 changed files with 85 additions and 23 deletions
  1. 85 23
      direct/src/fsm/FSM.py

+ 85 - 23
direct/src/fsm/FSM.py

@@ -13,6 +13,8 @@ from direct.showbase.MessengerGlobal import messenger
 from direct.showbase import PythonUtil
 from direct.showbase import PythonUtil
 from direct.directnotify import DirectNotifyGlobal
 from direct.directnotify import DirectNotifyGlobal
 from direct.stdpy.threading import RLock
 from direct.stdpy.threading import RLock
+from panda3d.core import AsyncTaskManager, AsyncFuture, PythonTask
+import types
 
 
 
 
 class FSMException(Exception):
 class FSMException(Exception):
@@ -27,6 +29,19 @@ class RequestDenied(FSMException):
     pass
     pass
 
 
 
 
+class Transition(tuple):
+    """Used for the return value of fsm.request().  Behaves like a tuple, for
+    historical reasons."""
+
+    _future = None
+
+    def __await__(self):
+        if self._future:
+            yield self._future
+
+        return tuple(self)
+
+
 class FSM(DirectObject):
 class FSM(DirectObject):
     """
     """
     A Finite State Machine.  This is intended to be the base class
     A Finite State Machine.  This is intended to be the base class
@@ -154,6 +169,9 @@ class FSM(DirectObject):
     # must be approved by some filter function.
     # must be approved by some filter function.
     defaultTransitions = None
     defaultTransitions = None
 
 
+    __doneFuture = AsyncFuture()
+    __doneFuture.set_result(None)
+
     # An enum class for special states like the DEFAULT or ANY state,
     # An enum class for special states like the DEFAULT or ANY state,
     # that should be treatened by the FSM in a special way
     # that should be treatened by the FSM in a special way
     class EnumStates():
     class EnumStates():
@@ -247,7 +265,13 @@ class FSM(DirectObject):
     def forceTransition(self, request, *args):
     def forceTransition(self, request, *args):
         """Changes unconditionally to the indicated state.  This
         """Changes unconditionally to the indicated state.  This
         bypasses the filterState() function, and just calls
         bypasses the filterState() function, and just calls
-        exitState() followed by enterState()."""
+        exitState() followed by enterState().
+
+        If the FSM is currently undergoing a transition, this will
+        queue up the new transition.
+
+        Returns a future, which can be used to await the transition.
+        """
 
 
         self.fsmLock.acquire()
         self.fsmLock.acquire()
         try:
         try:
@@ -257,11 +281,13 @@ class FSM(DirectObject):
 
 
             if not self.state:
             if not self.state:
                 # Queue up the request.
                 # Queue up the request.
-                self.__requestQueue.append(PythonUtil.Functor(
-                    self.forceTransition, request, *args))
-                return
+                fut = AsyncFuture()
+                self.__requestQueue.append((PythonUtil.Functor(
+                    self.forceTransition, request, *args), fut))
+                return fut
 
 
-            self.__setState(request, *args)
+            result = self.__setState(request, *args)
+            return result._future or self.__doneFuture
         finally:
         finally:
             self.fsmLock.release()
             self.fsmLock.release()
 
 
@@ -275,6 +301,10 @@ class FSM(DirectObject):
         request is queued up and will be executed when the current
         request is queued up and will be executed when the current
         transition finishes.  Multiple requests will queue up in
         transition finishes.  Multiple requests will queue up in
         sequence.
         sequence.
+
+        The return value of this function can be used in an `await`
+        expression to suspend the current coroutine until the
+        transition is done.
         """
         """
 
 
         self.fsmLock.acquire()
         self.fsmLock.acquire()
@@ -284,12 +314,15 @@ class FSM(DirectObject):
                 self._name, request, str(args)[1:]))
                 self._name, request, str(args)[1:]))
             if not self.state:
             if not self.state:
                 # Queue up the request.
                 # Queue up the request.
-                self.__requestQueue.append(PythonUtil.Functor(
-                    self.demand, request, *args))
-                return
+                fut = AsyncFuture()
+                self.__requestQueue.append((PythonUtil.Functor(
+                    self.demand, request, *args), fut))
+                return fut
 
 
-            if not self.request(request, *args):
+            result = self.request(request, *args)
+            if not result:
                 raise RequestDenied("%s (from state: %s)" % (request, self.state))
                 raise RequestDenied("%s (from state: %s)" % (request, self.state))
+            return result._future or self.__doneFuture
         finally:
         finally:
             self.fsmLock.release()
             self.fsmLock.release()
 
 
@@ -314,7 +347,12 @@ class FSM(DirectObject):
         executing an enterState or exitState function), an
         executing an enterState or exitState function), an
         `AlreadyInTransition` exception is raised (but see `demand()`,
         `AlreadyInTransition` exception is raised (but see `demand()`,
         which will queue these requests up and apply when the
         which will queue these requests up and apply when the
-        transition is complete)."""
+        transition is complete).
+
+        If the previous state's exitFunc or the new state's enterFunc
+        is a coroutine, the state change may not have been applied by
+        the time request() returns, but you can use `await` on the
+        return value to await the transition."""
 
 
         self.fsmLock.acquire()
         self.fsmLock.acquire()
         try:
         try:
@@ -331,7 +369,7 @@ class FSM(DirectObject):
                     result = (result,) + args
                     result = (result,) + args
 
 
                 # Otherwise, assume it's a (name, *args) tuple
                 # Otherwise, assume it's a (name, *args) tuple
-                self.__setState(*result)
+                return self.__setState(*result)
 
 
             return result
             return result
         finally:
         finally:
@@ -441,11 +479,11 @@ class FSM(DirectObject):
         try:
         try:
             if self.stateArray:
             if self.stateArray:
                 if not self.state in self.stateArray:
                 if not self.state in self.stateArray:
-                    self.request(self.stateArray[0])
+                    return self.request(self.stateArray[0])
                 else:
                 else:
                     cur_index = self.stateArray.index(self.state)
                     cur_index = self.stateArray.index(self.state)
                     new_index = (cur_index + 1) % len(self.stateArray)
                     new_index = (cur_index + 1) % len(self.stateArray)
-                    self.request(self.stateArray[new_index], args)
+                    return self.request(self.stateArray[new_index], args)
             else:
             else:
                 assert self.notifier.debug(
                 assert self.notifier.debug(
                                     "stateArray empty. Can't switch to next.")
                                     "stateArray empty. Can't switch to next.")
@@ -459,11 +497,11 @@ class FSM(DirectObject):
         try:
         try:
             if self.stateArray:
             if self.stateArray:
                 if not self.state in self.stateArray:
                 if not self.state in self.stateArray:
-                    self.request(self.stateArray[0])
+                    return self.request(self.stateArray[0])
                 else:
                 else:
                     cur_index = self.stateArray.index(self.state)
                     cur_index = self.stateArray.index(self.state)
                     new_index = (cur_index - 1) % len(self.stateArray)
                     new_index = (cur_index - 1) % len(self.stateArray)
-                    self.request(self.stateArray[new_index], args)
+                    return self.request(self.stateArray[new_index], args)
             else:
             else:
                 assert self.notifier.debug(
                 assert self.notifier.debug(
                                     "stateArray empty. Can't switch to next.")
                                     "stateArray empty. Can't switch to next.")
@@ -471,8 +509,26 @@ class FSM(DirectObject):
             self.fsmLock.release()
             self.fsmLock.release()
 
 
     def __setState(self, newState, *args):
     def __setState(self, newState, *args):
-        # Internal function to change unconditionally to the indicated
-        # state.
+        # Internal function to change unconditionally to the indicated state.
+
+        transition = Transition((newState,) + args)
+
+        # See if we can transition immediately by polling the coroutine.
+        coro = self.__transition(newState, *args)
+        try:
+            coro.send(None)
+        except StopIteration:
+            # We managed to apply this straight away.
+            return transition
+
+        # Continue the state transition in a task.
+        task = PythonTask(coro)
+        mgr = AsyncTaskManager.get_global_ptr()
+        mgr.add(task)
+        transition._future = task
+        return transition
+
+    async def __transition(self, newState, *args):
         assert self.state
         assert self.state
         assert self.notify.debug("%s to state %s." % (self._name, newState))
         assert self.notify.debug("%s to state %s." % (self._name, newState))
 
 
@@ -482,8 +538,13 @@ class FSM(DirectObject):
 
 
         try:
         try:
             if not self.__callFromToFunc(self.oldState, self.newState, *args):
             if not self.__callFromToFunc(self.oldState, self.newState, *args):
-                self.__callExitFunc(self.oldState)
-                self.__callEnterFunc(self.newState, *args)
+                result = self.__callExitFunc(self.oldState)
+                if isinstance(result, types.CoroutineType):
+                    await result
+
+                result = self.__callEnterFunc(self.newState, *args)
+                if isinstance(result, types.CoroutineType):
+                    await result
         except:
         except:
             # If we got an exception during the enter or exit methods,
             # If we got an exception during the enter or exit methods,
             # go directly to state "InternalError" and raise up the
             # go directly to state "InternalError" and raise up the
@@ -503,9 +564,10 @@ class FSM(DirectObject):
         del self.newState
         del self.newState
 
 
         if self.__requestQueue:
         if self.__requestQueue:
-            request = self.__requestQueue.pop(0)
+            request, fut = self.__requestQueue.pop(0)
             assert self.notify.debug("%s continued queued request." % (self._name))
             assert self.notify.debug("%s continued queued request." % (self._name))
-            request()
+            await request()
+            fut.set_result(None)
 
 
     def __callEnterFunc(self, name, *args):
     def __callEnterFunc(self, name, *args):
         # Calls the appropriate enter function when transitioning into
         # Calls the appropriate enter function when transitioning into
@@ -517,7 +579,7 @@ class FSM(DirectObject):
             # If there's no matching enterFoo() function, call
             # If there's no matching enterFoo() function, call
             # defaultEnter() instead.
             # defaultEnter() instead.
             func = self.defaultEnter
             func = self.defaultEnter
-        func(*args)
+        return func(*args)
 
 
     def __callFromToFunc(self, oldState, newState, *args):
     def __callFromToFunc(self, oldState, newState, *args):
         # Calls the appropriate fromTo function when transitioning into
         # Calls the appropriate fromTo function when transitioning into
@@ -540,7 +602,7 @@ class FSM(DirectObject):
             # If there's no matching exitFoo() function, call
             # If there's no matching exitFoo() function, call
             # defaultExit() instead.
             # defaultExit() instead.
             func = self.defaultExit
             func = self.defaultExit
-        func()
+        return func()
 
 
     def __repr__(self):
     def __repr__(self):
         return self.__str__()
         return self.__str__()