Browse Source

Finish reimplementation of ReaderWriterLockSlim so that it pass unit tests.

Jérémie Laval 15 years ago
parent
commit
dbf53c74c4
1 changed files with 186 additions and 72 deletions
  1. 186 72
      mcs/class/System.Core/System.Threading/ReaderWriterLockSlim.cs

+ 186 - 72
mcs/class/System.Core/System.Threading/ReaderWriterLockSlim.cs

@@ -35,43 +35,63 @@ using System.Threading;
 
 namespace System.Threading {
 
-	[HostProtectionAttribute(SecurityAction.LinkDemand, MayLeakOnAbort = true)]
-	[HostProtectionAttribute(SecurityAction.LinkDemand, Synchronization = true, ExternalThreading = true)]
-	public class ReaderWriterLockSlim : IDisposable 
+	[Flags]
+	internal enum ThreadLockState
+	{
+		None = 0,
+		Read = 1,
+		Write = 2,
+		Upgradable = 4,
+		UpgradedRead = 5,
+		UpgradedWrite = 6
+	}
+
+	internal static class ThreadLockStateExtensions
 	{
-		enum ThreadLockState
+		internal static bool Has (this ThreadLockState state, ThreadLockState value)
 		{
-			None = 0,
-			Read,
-			Write,
-			Upgradable
+			return (state & value) > 0;
 		}
+	}
 
+	[HostProtectionAttribute(SecurityAction.LinkDemand, MayLeakOnAbort = true)]
+	[HostProtectionAttribute(SecurityAction.LinkDemand, Synchronization = true, ExternalThreading = true)]
+	public class ReaderWriterLockSlim : IDisposable
+	{
 		/* Position of each bit isn't really important 
 		 * but their relative order is
 		 */
 		const int RwWaitBit = 0;
-		const int RwWriteBit = 1;
-		const int RwReadBit = 2;
+		const int RwWaitUpgradeBit = 1;
+		const int RwWriteBit = 2;
+		const int RwReadBit = 3;
 
 		const int RwWait = 1;
-		const int RwWrite = 2;
-		const int RwRead = 4;
+		const int RwWaitUpgrade = 2;
+		const int RwWrite = 4;
+		const int RwRead = 8;
 
 		int rwlock;
 		
 		readonly LockRecursionPolicy recursionPolicy;
-		AtomicBoolean upgradableTaken;
+
+		AtomicBoolean upgradableTaken = new AtomicBoolean ();
+		ManualResetEventSlim upgradableEvent = new ManualResetEventSlim (true);
+
+		int numReadWaiters, numUpgradeWaiters, numWriteWaiters;
+		bool disposed;
 
 		[ThreadStatic]
-		IDictionary<ReaderWriterLockSlim, ThreadLockState> currentThreadState;
+		static IDictionary<ReaderWriterLockSlim, ThreadLockState> currentThreadState;
 
-		public ReaderWriterLockSlim (LockRecursionPolicy.None)
+		public ReaderWriterLockSlim () : this (LockRecursionPolicy.NoRecursion)
 		{
 		}
 
 		public ReaderWriterLockSlim (LockRecursionPolicy recursionPolicy)
 		{
+			if (recursionPolicy != LockRecursionPolicy.NoRecursion)
+				throw new NotSupportedException ("Creating a recursion-aware reader-writer lock is not yet supported");
 			this.recursionPolicy = recursionPolicy;
 		}
 
@@ -84,37 +104,51 @@ namespace System.Threading {
 		{
 			if (CheckState (millisecondsTimeout, ThreadLockState.Read))
 				return true;
+
+			// This is downgrading from upgradable, no need for check since
+			// we already have a sort-of read lock that's going to disappear
+			// after user calls ExitUpgradeableReadLock
+			if (CurrentThreadState.Has (ThreadLockState.Upgradable)) {
+				Interlocked.Add (ref rwlock, RwRead);
+				CurrentThreadState = CurrentThreadState ^ ThreadLockState.Read;
+				return true;
+			}
 			
 			Stopwatch sw = Stopwatch.StartNew ();
-			while (millisecondsTimeout < || && sw.ElapsedMilliseconds < millisecondsTimeout) {
-				if ((rwlock & (RwWrite | RwWait)) > 0) {
-					// Should sleep
+			Interlocked.Increment (ref numReadWaiters);
+
+			while (millisecondsTimeout == -1 || sw.ElapsedMilliseconds < millisecondsTimeout) {
+				if ((rwlock & 0x7) > 0) {
+					Thread.Sleep (1);
 					continue;
 				}
-				
-				if ((Interlocked.Add (ref rwlock, RwRead) & (RwWrite | RwWait)) == 0) {
-					CurrentThreadState = ThreadLockState.Read;
+
+				if ((Interlocked.Add (ref rwlock, RwRead) & 0x7) == 0) {
+					CurrentThreadState = CurrentThreadState ^ ThreadLockState.Read;
+					Interlocked.Decrement (ref numReadWaiters);
 					return true;
 				}
 
 				Interlocked.Add (ref rwlock, -RwRead);
-				// Should sleep
+
+				Thread.Sleep (1);
 			}
 
+			Interlocked.Decrement (ref numReadWaiters);
 			return false;
 		}
 
 		public bool TryEnterReadLock (TimeSpan timeout)
 		{
-			return TryEnterReadLock (timeout.TotalMilliseconds);
+			return TryEnterReadLock (CheckTimeout (timeout));
 		}
 
 		public void ExitReadLock ()
 		{
-			if (CurrentThreadState != Read)
+			if (CurrentThreadState != ThreadLockState.Read)
 				throw new SynchronizationLockException ("The current thread has not entered the lock in read mode");
-			
-			CurrentThreadState = None;
+
+			CurrentThreadState = ThreadLockState.None;
 			Interlocked.Add (ref rwlock, -RwRead);
 		}
 
@@ -125,45 +159,54 @@ namespace System.Threading {
 		
 		public bool TryEnterWriteLock (int millisecondsTimeout)
 		{
-			if (CheckState (millisecondsTimeout, ThreadLockState.Write))
+			bool isUpgradable = CurrentThreadState.Has (ThreadLockState.Upgradable);
+			if (CheckState (millisecondsTimeout, isUpgradable ? ThreadLockState.UpgradedWrite : ThreadLockState.Write))
 				return true;
 
 			Stopwatch sw = Stopwatch.StartNew ();
+			Interlocked.Increment (ref numWriteWaiters);
+
+			// If the code goes there that means we had a read lock beforehand
+			if (isUpgradable && rwlock >= RwRead)
+				Interlocked.Add (ref rwlock, -RwRead);
+
+			int stateCheck = isUpgradable ? RwWaitUpgrade : RwWait;
+			int appendValue = RwWait | (isUpgradable ? RwWaitUpgrade : 0);
 
 			while (millisecondsTimeout < 0 || sw.ElapsedMilliseconds < millisecondsTimeout) {
 				int state = rwlock;
 
-				if (state < RwWrite) {
-					if (Interlocked.CompareExchange (ref rwlock, state, RwWrite) == state) {
-						CurrentThreadState = Write;
+				if (state <= stateCheck) {
+					if (Interlocked.CompareExchange (ref rwlock, RwWrite, state) == state) {
+						CurrentThreadState = isUpgradable ? ThreadLockState.UpgradedWrite : ThreadLockState.Write;
+						Interlocked.Decrement (ref numWriteWaiters);
 						return true;
 					}
 					state = rwlock;
 				}
 
-				while ((state & RwWait) == 0 && Interlocked.CompareExchange (ref rwlock, state, state | RwWait) != state)
+				while ((state & RwWait) == 0 && Interlocked.CompareExchange (ref rwlock, state | appendValue, state) == state)
 					state = rwlock;
 
-				while (rwlock > RwWait && (millisecondsTimeout < 0 || sw.ElapsedMilliseconds < millisecondsTimeout)) {
-					// Should wait here
-					
-				}
+				while (rwlock > stateCheck && (millisecondsTimeout < 0 || sw.ElapsedMilliseconds < millisecondsTimeout))
+					Thread.Sleep (1);
 			}
 
+			Interlocked.Decrement (ref numWriteWaiters);
 			return false;
 		}
 
 		public bool TryEnterWriteLock (TimeSpan timeout)
 		{
-			return TryEnterWriteLock (timeout.TotalMilliseconds);
+			return TryEnterWriteLock (CheckTimeout (timeout));
 		}
 
 		public void ExitWriteLock ()
 		{
-			if (CurrentThreadState != Read)
-				throw new SynchronizationLockException ("The current thread has not entered the lock in read mode");
+			if (!CurrentThreadState.Has (ThreadLockState.Write))
+				throw new SynchronizationLockException ("The current thread has not entered the lock in write mode");
 			
-			CurrentThreadState = None;
+			CurrentThreadState = CurrentThreadState ^ ThreadLockState.Write;
 			Interlocked.Add (ref rwlock, -RwWrite);
 		}
 
@@ -181,8 +224,35 @@ namespace System.Threading {
 			if (CheckState (millisecondsTimeout, ThreadLockState.Upgradable))
 				return true;
 
-			if (CurrentThreadState == ThreadLockState.Read)
-				throw new LockRecursionException ("The current thread has already entered read mode");				
+			if (CurrentThreadState.Has (ThreadLockState.Read))
+				throw new LockRecursionException ("The current thread has already entered read mode");
+
+			Stopwatch sw = Stopwatch.StartNew ();
+			Interlocked.Increment (ref numUpgradeWaiters);
+
+			while (!upgradableEvent.IsSet || !upgradableTaken.TryRelaxedSet ()) {
+				if (millisecondsTimeout != -1 && sw.ElapsedMilliseconds > millisecondsTimeout) {
+					Interlocked.Decrement (ref numUpgradeWaiters);
+					return false;
+				}
+
+				upgradableEvent.Wait (ComputeTimeout (millisecondsTimeout, sw));
+			}
+
+			upgradableEvent.Reset ();
+
+			if (TryEnterReadLock (ComputeTimeout (millisecondsTimeout, sw))) {
+				CurrentThreadState = ThreadLockState.Upgradable;
+				Interlocked.Decrement (ref numUpgradeWaiters);
+				return true;
+			}
+
+			upgradableTaken.Value = false;
+			upgradableEvent.Set ();
+
+			Interlocked.Decrement (ref numUpgradeWaiters);
+
+			return false;
 		}
 
 		public bool TryEnterUpgradeableReadLock (TimeSpan timeout)
@@ -192,79 +262,87 @@ namespace System.Threading {
 	       
 		public void ExitUpgradeableReadLock ()
 		{
+			if (!CurrentThreadState.Has (ThreadLockState.Upgradable | ThreadLockState.Read))
+				throw new SynchronizationLockException ("The current thread has not entered the lock in upgradable mode");
 			
-		}
+			upgradableTaken.Value = false;
+			upgradableEvent.Set ();
 
-		bool CheckState (int millisecondsTimeout, ThreadLockState validState)
-		{
-			if (millisecondsTimeout < Timeout.Infinite)
-				throw new ArgumentOutOfRangeException ("millisecondsTimeout");
-			
-			// Detect and prevent recursion
-			if (recursionPolicy == LockRecursionPolicy.None && CurrentThreadState != None)
-				throw new LockRecursionException ("The current thread has already a lock and recursion isn't supported");
-			
-			// If we already had write lock, just return
-			if (CurrentThreadState == validState)
-				return true;
+			CurrentThreadState = CurrentThreadState ^ ThreadLockState.Upgradable;
+			Interlocked.Add (ref rwlock, -RwRead);
 		}
 
 		public void Dispose ()
 		{
-			read_locks = null;
+			disposed = true;
 		}
 
 		public bool IsReadLockHeld {
-			get { return RecursiveReadCount != 0; }
+			get {
+				return rwlock >= RwRead;
+			}
 		}
 		
 		public bool IsWriteLockHeld {
-			get { return RecursiveWriteCount != 0; }
+			get {
+				return (rwlock & RwWrite) > 0;
+			}
 		}
 		
 		public bool IsUpgradeableReadLockHeld {
-			get { return RecursiveUpgradeCount != 0; }
+			get {
+				return upgradableTaken.Value;
+			}
 		}
 
 		public int CurrentReadCount {
-			get { return owners & 0xFFFFFFF; }
+			get {
+				return (rwlock >> RwReadBit) - (IsUpgradeableReadLockHeld ? 1 : 0);
+			}
 		}
 		
 		public int RecursiveReadCount {
 			get {
-				EnterMyLock ();
-				LockDetails ld = GetReadLockDetails (Thread.CurrentThread.ManagedThreadId, false);
-				int count = ld == null ? 0 : ld.ReadLocks;
-				ExitMyLock ();
-				return count;
+				return IsReadLockHeld ? IsUpgradeableReadLockHeld ? 0 : 1 : 0;
 			}
 		}
 
 		public int RecursiveUpgradeCount {
-			get { return upgradable_thread == Thread.CurrentThread ? 1 : 0; }
+			get {
+				return IsUpgradeableReadLockHeld ? 1 : 0;
+			}
 		}
 
 		public int RecursiveWriteCount {
-			get { return write_thread == Thread.CurrentThread ? 1 : 0; }
+			get {
+				return IsWriteLockHeld ? 1 : 0;
+			}
 		}
 
 		public int WaitingReadCount {
-			get { return (int) numReadWaiters; }
+			get {
+				return numReadWaiters;
+			}
 		}
 
 		public int WaitingUpgradeCount {
-			get { return (int) numUpgradeWaiters; }
+			get {
+				return numUpgradeWaiters;
+			}
 		}
 
 		public int WaitingWriteCount {
-			get { return (int) numWriteWaiters; }
+			get {
+				return numWriteWaiters;
+			}
 		}
 
 		public LockRecursionPolicy RecursionPolicy {
-			get { return recursionPolicy; }
+			get {
+				return recursionPolicy;
+			}
 		}
 		
-#region Private methods
 		ThreadLockState CurrentThreadState {
 			get {
 				// TODO: provide a IEqualityComparer thingie to have better hashes
@@ -284,6 +362,42 @@ namespace System.Threading {
 				currentThreadState[this] = value;
 			}
 		}
-#endregion
+
+		bool CheckState (int millisecondsTimeout, ThreadLockState validState)
+		{
+			if (disposed)
+				throw new ObjectDisposedException ("ReaderWriterLockSlim");
+
+			if (millisecondsTimeout < Timeout.Infinite)
+				throw new ArgumentOutOfRangeException ("millisecondsTimeout");
+
+			// Detect and prevent recursion
+			ThreadLockState ctstate = CurrentThreadState;
+
+			if (recursionPolicy == LockRecursionPolicy.NoRecursion)
+				if ((ctstate != ThreadLockState.None && ctstate != ThreadLockState.Upgradable)
+				    || (ctstate == ThreadLockState.Upgradable && validState == ThreadLockState.Upgradable))
+					throw new LockRecursionException ("The current thread has already a lock and recursion isn't supported");
+
+			// If we already had lock, just return
+			if (CurrentThreadState == validState)
+				return true;
+
+			return false;
+		}
+
+		static int CheckTimeout (TimeSpan timeout)
+		{
+			try {
+				return checked ((int)timeout.TotalMilliseconds);
+			} catch (System.OverflowException) {
+				throw new ArgumentOutOfRangeException ("timeout");
+			}
+		}
+
+		static int ComputeTimeout (int millisecondsTimeout, Stopwatch sw)
+		{
+			return millisecondsTimeout == -1 ? -1 : (int)Math.Max (sw.ElapsedMilliseconds - millisecondsTimeout, 1);
+		}
 	}
 }