Browse Source

make fsm's thread safe

David Rose 17 years ago
parent
commit
940a46fcfd
1 changed files with 108 additions and 66 deletions
  1. 108 66
      direct/src/fsm/FSM.py

+ 108 - 66
direct/src/fsm/FSM.py

@@ -10,6 +10,7 @@ previously called FSM.py (now called ClassicFSM.py).
 from direct.showbase.DirectObject import DirectObject
 from direct.directnotify import DirectNotifyGlobal
 from direct.showbase import PythonUtil
+from direct.stdpy.threading import RLock
 import types
 import string
 
@@ -142,6 +143,7 @@ class FSM(DirectObject):
     defaultTransitions = None
 
     def __init__(self, name):
+        self.lock = RLock()
         self.name = name
         self._serialNum = FSM.SerialNum
         FSM.SerialNum += 1
@@ -166,9 +168,13 @@ class FSM(DirectObject):
     def cleanup(self):
         # A convenience function to force the FSM to clean itself up
         # by transitioning to the "Off" state.
-        assert self.state
-        if self.state != 'Off':
-            self.__setState('Off')
+        self.lock.acquire()
+        try:
+            assert self.state
+            if self.state != 'Off':
+                self.__setState('Off')
+        finally:
+            self.lock.release()
 
     def setBroadcastStateChanges(self, doBroadcast):
         self._broadcastStateChanges = doBroadcast
@@ -183,29 +189,41 @@ class FSM(DirectObject):
         # Returns the current state if we are in a state now, or the
         # state we are transitioning into if we are currently within
         # the enter or exit function for a state.
-        if self.state:
-            return self.state
-        return self.newState
+        self.lock.acquire()
+        try:
+            if self.state:
+                return self.state
+            return self.newState
+        finally:
+            self.lock.release()
 
     def isInTransition(self):
-        return self.state == None
+        self.lock.acquire()
+        try:
+            return self.state == None
+        finally:
+            self.lock.release()
     
     def forceTransition(self, request, *args):
         """Changes unconditionally to the indicated state.  This
         bypasses the filterState() function, and just calls
         exitState() followed by enterState()."""
 
-        assert isinstance(request, types.StringTypes)
-        self.notify.debug("%s.forceTransition(%s, %s" % (
-            self.name, request, str(args)[1:]))
+        self.lock.acquire()
+        try:
+            assert isinstance(request, types.StringTypes)
+            self.notify.debug("%s.forceTransition(%s, %s" % (
+                self.name, request, str(args)[1:]))
 
-        if not self.state:
-            # Queue up the request.
-            self.__requestQueue.append(PythonUtil.Functor(
-                self.forceTransition, request, *args))
-            return
+            if not self.state:
+                # Queue up the request.
+                self.__requestQueue.append(PythonUtil.Functor(
+                    self.forceTransition, request, *args))
+                return
 
-        self.__setState(request, *args)
+            self.__setState(request, *args)
+        finally:
+            self.lock.release()
 
     def demand(self, request, *args):
         """Requests a state transition, by code that does not expect
@@ -219,17 +237,21 @@ class FSM(DirectObject):
         sequence.
         """
 
-        assert isinstance(request, types.StringTypes)
-        self.notify.debug("%s.demand(%s, %s" % (
-            self.name, request, str(args)[1:]))
-        if not self.state:
-            # Queue up the request.
-            self.__requestQueue.append(PythonUtil.Functor(
-                self.demand, request, *args))
-            return
-
-        if not self.request(request, *args):
-            raise RequestDenied, "%s (from state: %s)" % (request, self.state)
+        self.lock.acquire()
+        try:
+            assert isinstance(request, types.StringTypes)
+            self.notify.debug("%s.demand(%s, %s" % (
+                self.name, request, str(args)[1:]))
+            if not self.state:
+                # Queue up the request.
+                self.__requestQueue.append(PythonUtil.Functor(
+                    self.demand, request, *args))
+                return
+
+            if not self.request(request, *args):
+                raise RequestDenied, "%s (from state: %s)" % (request, self.state)
+        finally:
+            self.lock.release()
 
     def request(self, request, *args):
         """Requests a state transition (or other behavior).  The
@@ -254,30 +276,34 @@ class FSM(DirectObject):
         which will queue these requests up and apply when the
         transition is complete)."""
 
-        assert isinstance(request, types.StringTypes)
-        self.notify.debug("%s.request(%s, %s" % (
-            self.name, request, str(args)[1:]))
-
-        if not self.state:
-            error = "requested %s while FSM is in transition from %s to %s." % (request, self.oldState, self.newState)
-            raise AlreadyInTransition, error
-
-        func = getattr(self, "filter" + self.state, None)
-        if not func:
-            # If there's no matching filterState() function, call
-            # defaultFilter() instead.
-            func = self.defaultFilter
-        result = func(request, args)
-        if result:
-            if isinstance(result, types.StringTypes):
-                # If the return value is a string, it's just the name
-                # of the state.  Wrap it in a tuple for consistency.
-                result = (result,) + args
-
-            # Otherwise, assume it's a (name, *args) tuple
-            self.__setState(*result)
-
-        return result
+        self.lock.acquire()
+        try:
+            assert isinstance(request, types.StringTypes)
+            self.notify.debug("%s.request(%s, %s" % (
+                self.name, request, str(args)[1:]))
+
+            if not self.state:
+                error = "requested %s while FSM is in transition from %s to %s." % (request, self.oldState, self.newState)
+                raise AlreadyInTransition, error
+
+            func = getattr(self, "filter" + self.state, None)
+            if not func:
+                # If there's no matching filterState() function, call
+                # defaultFilter() instead.
+                func = self.defaultFilter
+            result = func(request, args)
+            if result:
+                if isinstance(result, types.StringTypes):
+                    # If the return value is a string, it's just the name
+                    # of the state.  Wrap it in a tuple for consistency.
+                    result = (result,) + args
+
+                # Otherwise, assume it's a (name, *args) tuple
+                self.__setState(*result)
+
+            return result
+        finally:
+            self.lock.release()
 
     def defaultEnter(self, *args):
         """ This is the default function that is called if there is no
@@ -353,25 +379,37 @@ class FSM(DirectObject):
 
     def setStateArray(self, stateArray):
         """array of unique states to iterate through"""
-        self.stateArray = stateArray
+        self.lock.acquire()
+        try:
+            self.stateArray = stateArray
+        finally:
+            self.lock.release()
 
     def requestNext(self, *args):
         """request the 'next' state in the predefined state array"""
-        assert self.state in self.stateArray
+        self.lock.acquire()
+        try:
+            assert self.state in self.stateArray
 
-        curIndex = self.stateArray.index(self.state)
-        newIndex = (curIndex + 1) % len(self.stateArray)
+            curIndex = self.stateArray.index(self.state)
+            newIndex = (curIndex + 1) % len(self.stateArray)
 
-        self.request(self.stateArray[newIndex], args)
+            self.request(self.stateArray[newIndex], args)
+        finally:
+            self.lock.release()
 
     def requestPrev(self, *args):
         """request the 'previous' state in the predefined state array"""
-        assert self.state in self.stateArray
+        self.lock.acquire()
+        try:
+            assert self.state in self.stateArray
 
-        curIndex = self.stateArray.index(self.state)
-        newIndex = (curIndex - 1) % len(self.stateArray)
+            curIndex = self.stateArray.index(self.state)
+            newIndex = (curIndex - 1) % len(self.stateArray)
 
-        self.request(self.stateArray[newIndex], args)
+            self.request(self.stateArray[newIndex], args)
+        finally:
+            self.lock.release()
         
 
     def __setState(self, newState, *args):
@@ -441,9 +479,13 @@ class FSM(DirectObject):
         """
         Print out something useful about the fsm
         """
-        className = self.__class__.__name__
-        if self.state:
-            str = ('%s FSM:%s in state "%s"' % (className, self.name, self.state))
-        else:
-            str = ('%s FSM:%s not in any state' % (className, self.name))
-        return str
+        self.lock.acquire()
+        try:
+            className = self.__class__.__name__
+            if self.state:
+                str = ('%s FSM:%s in state "%s"' % (className, self.name, self.state))
+            else:
+                str = ('%s FSM:%s not in any state' % (className, self.name))
+            return str
+        finally:
+            self.lock.release()