Browse Source

make fsm's thread safe

David Rose 17 years ago
parent
commit
940a46fcfd
1 changed files with 108 additions and 66 deletions
  1. 108 66
      direct/src/fsm/FSM.py

+ 108 - 66
direct/src/fsm/FSM.py

@@ -10,6 +10,7 @@ previously called FSM.py (now called ClassicFSM.py).
 from direct.showbase.DirectObject import DirectObject
 from direct.showbase.DirectObject import DirectObject
 from direct.directnotify import DirectNotifyGlobal
 from direct.directnotify import DirectNotifyGlobal
 from direct.showbase import PythonUtil
 from direct.showbase import PythonUtil
+from direct.stdpy.threading import RLock
 import types
 import types
 import string
 import string
 
 
@@ -142,6 +143,7 @@ class FSM(DirectObject):
     defaultTransitions = None
     defaultTransitions = None
 
 
     def __init__(self, name):
     def __init__(self, name):
+        self.lock = RLock()
         self.name = name
         self.name = name
         self._serialNum = FSM.SerialNum
         self._serialNum = FSM.SerialNum
         FSM.SerialNum += 1
         FSM.SerialNum += 1
@@ -166,9 +168,13 @@ class FSM(DirectObject):
     def cleanup(self):
     def cleanup(self):
         # A convenience function to force the FSM to clean itself up
         # A convenience function to force the FSM to clean itself up
         # by transitioning to the "Off" state.
         # by transitioning to the "Off" state.
-        assert self.state
-        if self.state != 'Off':
-            self.__setState('Off')
+        self.lock.acquire()
+        try:
+            assert self.state
+            if self.state != 'Off':
+                self.__setState('Off')
+        finally:
+            self.lock.release()
 
 
     def setBroadcastStateChanges(self, doBroadcast):
     def setBroadcastStateChanges(self, doBroadcast):
         self._broadcastStateChanges = doBroadcast
         self._broadcastStateChanges = doBroadcast
@@ -183,29 +189,41 @@ class FSM(DirectObject):
         # Returns the current state if we are in a state now, or the
         # Returns the current state if we are in a state now, or the
         # state we are transitioning into if we are currently within
         # state we are transitioning into if we are currently within
         # the enter or exit function for a state.
         # the enter or exit function for a state.
-        if self.state:
-            return self.state
-        return self.newState
+        self.lock.acquire()
+        try:
+            if self.state:
+                return self.state
+            return self.newState
+        finally:
+            self.lock.release()
 
 
     def isInTransition(self):
     def isInTransition(self):
-        return self.state == None
+        self.lock.acquire()
+        try:
+            return self.state == None
+        finally:
+            self.lock.release()
     
     
     def forceTransition(self, request, *args):
     def forceTransition(self, request, *args):
         """Changes unconditionally to the indicated state.  This
         """Changes unconditionally to the indicated state.  This
         bypasses the filterState() function, and just calls
         bypasses the filterState() function, and just calls
         exitState() followed by enterState()."""
         exitState() followed by enterState()."""
 
 
-        assert isinstance(request, types.StringTypes)
-        self.notify.debug("%s.forceTransition(%s, %s" % (
-            self.name, request, str(args)[1:]))
+        self.lock.acquire()
+        try:
+            assert isinstance(request, types.StringTypes)
+            self.notify.debug("%s.forceTransition(%s, %s" % (
+                self.name, request, str(args)[1:]))
 
 
-        if not self.state:
-            # Queue up the request.
-            self.__requestQueue.append(PythonUtil.Functor(
-                self.forceTransition, request, *args))
-            return
+            if not self.state:
+                # Queue up the request.
+                self.__requestQueue.append(PythonUtil.Functor(
+                    self.forceTransition, request, *args))
+                return
 
 
-        self.__setState(request, *args)
+            self.__setState(request, *args)
+        finally:
+            self.lock.release()
 
 
     def demand(self, request, *args):
     def demand(self, request, *args):
         """Requests a state transition, by code that does not expect
         """Requests a state transition, by code that does not expect
@@ -219,17 +237,21 @@ class FSM(DirectObject):
         sequence.
         sequence.
         """
         """
 
 
-        assert isinstance(request, types.StringTypes)
-        self.notify.debug("%s.demand(%s, %s" % (
-            self.name, request, str(args)[1:]))
-        if not self.state:
-            # Queue up the request.
-            self.__requestQueue.append(PythonUtil.Functor(
-                self.demand, request, *args))
-            return
-
-        if not self.request(request, *args):
-            raise RequestDenied, "%s (from state: %s)" % (request, self.state)
+        self.lock.acquire()
+        try:
+            assert isinstance(request, types.StringTypes)
+            self.notify.debug("%s.demand(%s, %s" % (
+                self.name, request, str(args)[1:]))
+            if not self.state:
+                # Queue up the request.
+                self.__requestQueue.append(PythonUtil.Functor(
+                    self.demand, request, *args))
+                return
+
+            if not self.request(request, *args):
+                raise RequestDenied, "%s (from state: %s)" % (request, self.state)
+        finally:
+            self.lock.release()
 
 
     def request(self, request, *args):
     def request(self, request, *args):
         """Requests a state transition (or other behavior).  The
         """Requests a state transition (or other behavior).  The
@@ -254,30 +276,34 @@ class FSM(DirectObject):
         which will queue these requests up and apply when the
         which will queue these requests up and apply when the
         transition is complete)."""
         transition is complete)."""
 
 
-        assert isinstance(request, types.StringTypes)
-        self.notify.debug("%s.request(%s, %s" % (
-            self.name, request, str(args)[1:]))
-
-        if not self.state:
-            error = "requested %s while FSM is in transition from %s to %s." % (request, self.oldState, self.newState)
-            raise AlreadyInTransition, error
-
-        func = getattr(self, "filter" + self.state, None)
-        if not func:
-            # If there's no matching filterState() function, call
-            # defaultFilter() instead.
-            func = self.defaultFilter
-        result = func(request, args)
-        if result:
-            if isinstance(result, types.StringTypes):
-                # If the return value is a string, it's just the name
-                # of the state.  Wrap it in a tuple for consistency.
-                result = (result,) + args
-
-            # Otherwise, assume it's a (name, *args) tuple
-            self.__setState(*result)
-
-        return result
+        self.lock.acquire()
+        try:
+            assert isinstance(request, types.StringTypes)
+            self.notify.debug("%s.request(%s, %s" % (
+                self.name, request, str(args)[1:]))
+
+            if not self.state:
+                error = "requested %s while FSM is in transition from %s to %s." % (request, self.oldState, self.newState)
+                raise AlreadyInTransition, error
+
+            func = getattr(self, "filter" + self.state, None)
+            if not func:
+                # If there's no matching filterState() function, call
+                # defaultFilter() instead.
+                func = self.defaultFilter
+            result = func(request, args)
+            if result:
+                if isinstance(result, types.StringTypes):
+                    # If the return value is a string, it's just the name
+                    # of the state.  Wrap it in a tuple for consistency.
+                    result = (result,) + args
+
+                # Otherwise, assume it's a (name, *args) tuple
+                self.__setState(*result)
+
+            return result
+        finally:
+            self.lock.release()
 
 
     def defaultEnter(self, *args):
     def defaultEnter(self, *args):
         """ This is the default function that is called if there is no
         """ This is the default function that is called if there is no
@@ -353,25 +379,37 @@ class FSM(DirectObject):
 
 
     def setStateArray(self, stateArray):
     def setStateArray(self, stateArray):
         """array of unique states to iterate through"""
         """array of unique states to iterate through"""
-        self.stateArray = stateArray
+        self.lock.acquire()
+        try:
+            self.stateArray = stateArray
+        finally:
+            self.lock.release()
 
 
     def requestNext(self, *args):
     def requestNext(self, *args):
         """request the 'next' state in the predefined state array"""
         """request the 'next' state in the predefined state array"""
-        assert self.state in self.stateArray
+        self.lock.acquire()
+        try:
+            assert self.state in self.stateArray
 
 
-        curIndex = self.stateArray.index(self.state)
-        newIndex = (curIndex + 1) % len(self.stateArray)
+            curIndex = self.stateArray.index(self.state)
+            newIndex = (curIndex + 1) % len(self.stateArray)
 
 
-        self.request(self.stateArray[newIndex], args)
+            self.request(self.stateArray[newIndex], args)
+        finally:
+            self.lock.release()
 
 
     def requestPrev(self, *args):
     def requestPrev(self, *args):
         """request the 'previous' state in the predefined state array"""
         """request the 'previous' state in the predefined state array"""
-        assert self.state in self.stateArray
+        self.lock.acquire()
+        try:
+            assert self.state in self.stateArray
 
 
-        curIndex = self.stateArray.index(self.state)
-        newIndex = (curIndex - 1) % len(self.stateArray)
+            curIndex = self.stateArray.index(self.state)
+            newIndex = (curIndex - 1) % len(self.stateArray)
 
 
-        self.request(self.stateArray[newIndex], args)
+            self.request(self.stateArray[newIndex], args)
+        finally:
+            self.lock.release()
         
         
 
 
     def __setState(self, newState, *args):
     def __setState(self, newState, *args):
@@ -441,9 +479,13 @@ class FSM(DirectObject):
         """
         """
         Print out something useful about the fsm
         Print out something useful about the fsm
         """
         """
-        className = self.__class__.__name__
-        if self.state:
-            str = ('%s FSM:%s in state "%s"' % (className, self.name, self.state))
-        else:
-            str = ('%s FSM:%s not in any state' % (className, self.name))
-        return str
+        self.lock.acquire()
+        try:
+            className = self.__class__.__name__
+            if self.state:
+                str = ('%s FSM:%s in state "%s"' % (className, self.name, self.state))
+            else:
+                str = ('%s FSM:%s not in any state' % (className, self.name))
+            return str
+        finally:
+            self.lock.release()