Przeglądaj źródła

Add dispose checks to ReaderWriterLockSlim and fix race in currentThreadState initialization

Marek Safar 12 lat temu
rodzic
commit
cbca01b433

+ 23 - 18
mcs/class/System.Core/System.Threading/ReaderWriterLockSlim.cs

@@ -56,7 +56,7 @@ namespace System.Threading {
 		 * that are being made. The 3 lowest bits are used as flag to track "destructive" lock entries
 		 * (i.e attempting to take the write lock with or without having acquired an upgradeable lock beforehand).
 		 * All the remaining bits are intepreted as the actual number of reader currently using the lock
-		 * (which mean the lock is limited to 4294967288 concurrent readers but since it's a high number there
+		 * (which mean the lock is limited to 2^29 concurrent readers but since it's a high number there
 		 * is no overflow safe guard to remain simple).
 		 */
 		int rwlock;
@@ -98,7 +98,7 @@ namespace System.Threading {
 		 * instance are kept here.
 		 */
 		[ThreadStatic]
-		static IDictionary<int, ThreadLockState> currentThreadState;
+		static Dictionary<int, ThreadLockState> currentThreadState;
 
 		/* Rwls tries to use this array as much as possible to quickly retrieve the thread-local
 		 * informations so that it ends up being only an array lookup. When the number of thread
@@ -446,6 +446,12 @@ namespace System.Threading {
 
 		public void Dispose ()
 		{
+			if (disposed)
+				return;
+
+			if (IsReadLockHeld || IsUpgradeableReadLockHeld || IsWriteLockHeld)
+				throw new SynchronizationLockException ("The lock is being disposed while still being used");
+
 			disposed = true;
 		}
 
@@ -519,18 +525,22 @@ namespace System.Threading {
 			get {
 				int tid = Thread.CurrentThread.ManagedThreadId;
 
-				if (tid < fastStateCache.Length)
-					return fastStateCache[tid] == null ? (fastStateCache[tid] = new ThreadLockState ()) : fastStateCache[tid];
+				return tid < fastStateCache.Length ?
+					fastStateCache [tid] ?? (fastStateCache[tid] = new ThreadLockState ()) :
+					GetGlobalThreadState (tid);
+			}
+		}
 
-				if (currentThreadState == null)
-					currentThreadState = new Dictionary<int, ThreadLockState> ();
+		ThreadLockState GetGlobalThreadState (int tid)
+		{
+			if (currentThreadState == null)
+				Interlocked.CompareExchange (ref currentThreadState, new Dictionary<int, ThreadLockState> (), null);
 
-				ThreadLockState state;
-				if (!currentThreadState.TryGetValue (id, out state))
-					currentThreadState[id] = state = new ThreadLockState ();
+			ThreadLockState state;
+			if (!currentThreadState.TryGetValue (id, out state))
+				currentThreadState [id] = state = new ThreadLockState ();
 
-				return state;
-			}
+			return state;
 		}
 
 		bool CheckState (ThreadLockState state, int millisecondsTimeout, LockState validState)
@@ -554,16 +564,11 @@ namespace System.Threading {
 			if (ctstate.Has (validState))
 				return true;
 
-			CheckRecursionAuthorization (ctstate, validState);
-
-			return false;
-		}
-
-		static void CheckRecursionAuthorization (LockState ctstate, LockState desiredState)
-		{
 			// In read mode you can just enter Read recursively
 			if (ctstate == LockState.Read)
 				throw new LockRecursionException ();				
+
+			return false;
 		}
 
 		static int CheckTimeout (TimeSpan timeout)

+ 36 - 0
mcs/class/System.Core/Test/System.Threading/ReaderWriterLockSlimTest.cs

@@ -81,6 +81,42 @@ namespace MonoTests.System.Threading
 			}
 		}
 
+		[Test]
+		public void Dispose_WithReadLock ()
+		{
+			var rwl = new ReaderWriterLockSlim ();
+			rwl.EnterReadLock ();
+			try {
+				rwl.Dispose ();
+				Assert.Fail ("1");
+			} catch (SynchronizationLockException) {
+			}
+		}
+
+		[Test]
+		public void Dispose_WithWriteLock ()
+		{
+			var rwl = new ReaderWriterLockSlim ();
+			rwl.EnterWriteLock ();
+			try {
+				rwl.Dispose ();
+				Assert.Fail ("1");
+			} catch (SynchronizationLockException) {
+			}
+		}
+
+		[Test]
+		public void Dispose_UpgradeableReadLock ()
+		{
+			var rwl = new ReaderWriterLockSlim ();
+			rwl.EnterUpgradeableReadLock ();
+			try {
+				rwl.Dispose ();
+				Assert.Fail ("1");
+			} catch (SynchronizationLockException) {
+			}
+		}
+
 		[Test]
 		public void TryEnterReadLock_OutOfRange ()
 		{