Browse Source

Merge pull request #829 from symform/httpwebreq-async-ssl

Convert blocking operations in HttpWebRequest and SslClientStream to non-blocking
Sebastien Pouliot 12 years ago
parent
commit
35899c07ff

+ 147 - 3
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/RecordProtocol.cs

@@ -437,10 +437,88 @@ namespace Mono.Security.Protocol.Tls
 
 		public byte[] ReceiveRecord(Stream record)
 		{
+			if (this.context.ReceivedConnectionEnd)
+			{
+				throw new TlsException(
+					AlertDescription.InternalError,
+					"The session is finished and it's no longer valid.");
+			}
+
+			record_processing.Reset ();
+			byte[] recordTypeBuffer = new byte[1];
+
+			int bytesRead = record.Read(recordTypeBuffer, 0, recordTypeBuffer.Length);
+
+			//We're at the end of the stream. Time to bail.
+			if (bytesRead == 0)
+			{
+				return null;
+			}
+
+			// Try to read the Record Content Type
+			int type = recordTypeBuffer[0];
+
+			// Set last handshake message received to None
+			this.context.LastHandshakeMsg = HandshakeType.ClientHello;
 
-			IAsyncResult ar = this.BeginReceiveRecord(record, null, null);
-			return this.EndReceiveRecord(ar);
+			ContentType	contentType	= (ContentType)type;
+			byte[] buffer = this.ReadRecordBuffer(type, record);
+			if (buffer == null)
+			{
+				// record incomplete (at the moment)
+				return null;
+			}
 
+			// Decrypt message contents if needed
+			if (contentType == ContentType.Alert && buffer.Length == 2)
+			{
+			}
+			else if ((this.Context.Read != null) && (this.Context.Read.Cipher != null))
+			{
+				buffer = this.decryptRecordFragment (contentType, buffer);
+				DebugHelper.WriteLine ("Decrypted record data", buffer);
+			}
+
+			// Process record
+			switch (contentType)
+			{
+			case ContentType.Alert:
+				this.ProcessAlert((AlertLevel)buffer [0], (AlertDescription)buffer [1]);
+				if (record.CanSeek) 
+				{
+					// don't reprocess that memory block
+					record.SetLength (0); 
+				}
+				buffer = null;
+				break;
+
+			case ContentType.ChangeCipherSpec:
+				this.ProcessChangeCipherSpec();
+				break;
+
+			case ContentType.ApplicationData:
+				break;
+
+			case ContentType.Handshake:
+				TlsStream message = new TlsStream (buffer);
+				while (!message.EOF)
+				{
+					this.ProcessHandshakeMessage(message);
+				}
+				break;
+
+			case (ContentType)0x80:
+				this.context.HandshakeMessages.Write (buffer);
+				break;
+
+			default:
+				throw new TlsException(
+					AlertDescription.UnexpectedMessage,
+					"Unknown record received from server.");
+			}
+
+			record_processing.Set ();
+			return buffer;
 		}
 
 		private byte[] ReadRecordBuffer (int contentType, Stream record)
@@ -655,6 +733,57 @@ namespace Mono.Security.Protocol.Tls
 			}
 		}
 
+		public void SendChangeCipherSpec(Stream recordStream)
+		{
+			DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
+
+			byte[] record = this.EncodeRecord (ContentType.ChangeCipherSpec, new byte[] { 1 });
+
+			// Send Change Cipher Spec message with the current cipher
+			// or as plain text if this is the initial negotiation
+			recordStream.Write(record, 0, record.Length);
+
+			Context ctx = this.context;
+
+			// Reset sequence numbers
+			ctx.WriteSequenceNumber = 0;
+
+			// all further data sent will be encrypted with the negotiated
+			// security parameters (now the current parameters)
+			if (ctx is ClientContext) {
+				ctx.StartSwitchingSecurityParameters (true);
+			} else {
+				ctx.EndSwitchingSecurityParameters (false);
+			}
+		}
+
+		public IAsyncResult BeginSendChangeCipherSpec(AsyncCallback callback, object state)
+		{
+			DebugHelper.WriteLine (">>>> Write Change Cipher Spec");
+
+			// Send Change Cipher Spec message with the current cipher
+			// or as plain text if this is the initial negotiation
+			return this.BeginSendRecord (ContentType.ChangeCipherSpec, new byte[] { 1 }, callback, state);
+		}
+
+		public void EndSendChangeCipherSpec (IAsyncResult asyncResult)
+		{
+			this.EndSendRecord (asyncResult);
+
+			Context ctx = this.context;
+
+			// Reset sequence numbers
+			ctx.WriteSequenceNumber = 0;
+
+			// all further data sent will be encrypted with the negotiated
+			// security parameters (now the current parameters)
+			if (ctx is ClientContext) {
+				ctx.StartSwitchingSecurityParameters (true);
+			} else {
+				ctx.EndSwitchingSecurityParameters (false);
+			}
+		}
+
 		public IAsyncResult BeginSendRecord(HandshakeType handshakeType, AsyncCallback callback, object state)
 		{
 			HandshakeMessage msg = this.GetMessage(handshakeType);
@@ -793,7 +922,22 @@ namespace Mono.Security.Protocol.Tls
 
 			return record.ToArray();
 		}
-		
+
+		public byte[] EncodeHandshakeRecord(HandshakeType handshakeType)
+		{
+			HandshakeMessage msg = this.GetMessage(handshakeType);
+
+			msg.Process();
+
+			var bytes = this.EncodeRecord (msg.ContentType, msg.EncodeMessage ());
+
+			msg.Update();
+
+			msg.Reset();
+
+			return bytes;
+		}
+				
 		#endregion
 
 		#region Cryptography Methods

+ 280 - 96
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/SslClientStream.cs

@@ -280,139 +280,323 @@ namespace Mono.Security.Protocol.Tls
 					Fig. 1 - Message flow for a full handshake		
 		*/
 
-		internal override IAsyncResult OnBeginNegotiateHandshake(AsyncCallback callback, object state)
+		private void SafeEndReceiveRecord (IAsyncResult ar, bool ignoreEmpty = false)
 		{
-			try
+			byte[] record = this.protocol.EndReceiveRecord (ar);
+			if (!ignoreEmpty && ((record == null) || (record.Length == 0))) {
+				throw new TlsException (
+					AlertDescription.HandshakeFailiure,
+					"The server stopped the handshake.");
+			}
+		}
+
+		private enum NegotiateState
+		{
+			SentClientHello,
+			ReceiveClientHelloResponse,
+			SentCipherSpec,
+			ReceiveCipherSpecResponse,
+			SentKeyExchange,
+			ReceiveFinishResponse,
+			SentFinished,
+		};
+
+		private class NegotiateAsyncResult : IAsyncResult
+		{
+			private object locker = new object ();
+			private AsyncCallback _userCallback;
+			private object _userState;
+			private Exception _asyncException;
+			private ManualResetEvent handle;
+			private NegotiateState _state;
+			private bool completed;
+
+			public NegotiateAsyncResult(AsyncCallback userCallback, object userState, NegotiateState state)
 			{
-				if (this.context.HandshakeState != HandshakeState.None)
-				{
-					this.context.Clear();
-				}
+				_userCallback = userCallback;
+				_userState = userState;
+				_state = state;
+			}
 
-				// Obtain supported cipher suites
-				this.context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers(this.context.SecurityProtocol);
+			public NegotiateState State
+			{
+				get { return _state; }
+				set { _state = value; }
+			}
 
-				// Set handshake state
-				this.context.HandshakeState = HandshakeState.Started;
+			public object AsyncState
+			{
+				get { return _userState; }
+			}
 
-				// Send client hello
-				return this.protocol.BeginSendRecord(HandshakeType.ClientHello, callback, state);
+			public Exception AsyncException
+			{
+				get { return _asyncException; }
 			}
-			catch (TlsException ex)
+
+			public bool CompletedWithError
 			{
-				this.protocol.SendAlert(ex.Alert);
+				get {
+					if (!IsCompleted)
+						return false; // Perhaps throw InvalidOperationExcetion?
 
-				throw new IOException("The authentication or decryption has failed.", ex);
+					return null != _asyncException;
+				}
 			}
-			catch (Exception ex)
+
+			public WaitHandle AsyncWaitHandle
 			{
-				this.protocol.SendAlert(AlertDescription.InternalError);
+				get {
+					lock (locker) {
+						if (handle == null)
+							handle = new ManualResetEvent (completed);
+					}
+					return handle;
+				}
+
+			}
+
+			public bool CompletedSynchronously
+			{
+				get { return false; }
+			}
+
+			public bool IsCompleted
+			{
+				get {
+					lock (locker) {
+						return completed;
+					}
+				}
+			}
+
+			public void SetComplete(Exception ex)
+			{
+				lock (locker) {
+					if (completed)
+						return;
 
-				throw new IOException("The authentication or decryption has failed.", ex);
+					completed = true;
+					if (handle != null)
+						handle.Set ();
+
+					if (_userCallback != null)
+						_userCallback.BeginInvoke (this, null, null);
+
+					_asyncException = ex;
+				}
+			}
+
+			public void SetComplete()
+			{
+				SetComplete(null);
 			}
 		}
 
-		private void SafeReceiveRecord (Stream s, bool ignoreEmpty = false)
+		internal override IAsyncResult BeginNegotiateHandshake(AsyncCallback callback, object state)
 		{
-			byte[] record = this.protocol.ReceiveRecord (s);
-			if (!ignoreEmpty && ((record == null) || (record.Length == 0))) {
-				throw new TlsException (
-					AlertDescription.HandshakeFailiure,
-					"The server stopped the handshake.");
+			if (this.context.HandshakeState != HandshakeState.None) {
+				this.context.Clear ();
 			}
+
+			// Obtain supported cipher suites
+			this.context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers (this.context.SecurityProtocol);
+
+			// Set handshake state
+			this.context.HandshakeState = HandshakeState.Started;
+
+			NegotiateAsyncResult result = new NegotiateAsyncResult (callback, state, NegotiateState.SentClientHello);
+
+			// Begin sending the client hello
+			this.protocol.BeginSendRecord (HandshakeType.ClientHello, NegotiateAsyncWorker, result);
+
+			return result;
+		}
+
+		internal override void EndNegotiateHandshake (IAsyncResult result)
+		{
+			NegotiateAsyncResult negotiate = result as NegotiateAsyncResult;
+
+			if (negotiate == null)
+				throw new ArgumentNullException ();
+			if (!negotiate.IsCompleted)
+				negotiate.AsyncWaitHandle.WaitOne();
+			if (negotiate.CompletedWithError)
+				throw negotiate.AsyncException;
 		}
 
-		internal override void OnNegotiateHandshakeCallback(IAsyncResult asyncResult)
+		private void NegotiateAsyncWorker (IAsyncResult result)
 		{
-			this.protocol.EndSendRecord(asyncResult);
+			NegotiateAsyncResult negotiate = result.AsyncState as NegotiateAsyncResult;
 
-			// Read server response
-			while (this.context.LastHandshakeMsg != HandshakeType.ServerHelloDone) 
+			try
 			{
-				// Read next record (skip empty, e.g. warnings alerts)
-				SafeReceiveRecord (this.innerStream, true);
+				switch (negotiate.State)
+				{
+				case NegotiateState.SentClientHello:
+					this.protocol.EndSendRecord (result);
 
-				// special case for abbreviated handshake where no ServerHelloDone is sent from the server
-				if (this.context.AbbreviatedHandshake && (this.context.LastHandshakeMsg == HandshakeType.ServerHello))
+					// we are now ready to ready the receive the hello response.
+					negotiate.State = NegotiateState.ReceiveClientHelloResponse;
+
+					// Start reading the client hello response
+					this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
 					break;
-			}
 
-			// the handshake is much easier if we can reuse a previous session settings
-			if (this.context.AbbreviatedHandshake) 
-			{
-				ClientSessionCache.SetContextFromCache (this.context);
-				this.context.Negotiating.Cipher.ComputeKeys ();
-				this.context.Negotiating.Cipher.InitializeCipher ();
+				case NegotiateState.ReceiveClientHelloResponse:
+					this.SafeEndReceiveRecord (result, true);
+
+					if (this.context.LastHandshakeMsg != HandshakeType.ServerHelloDone &&
+						(!this.context.AbbreviatedHandshake || this.context.LastHandshakeMsg != HandshakeType.ServerHello)) {
+						// Read next record (skip empty, e.g. warnings alerts)
+						this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+						break;
+					}
+
+					// special case for abbreviated handshake where no ServerHelloDone is sent from the server
+					if (this.context.AbbreviatedHandshake) {
+						ClientSessionCache.SetContextFromCache (this.context);
+						this.context.Negotiating.Cipher.ComputeKeys ();
+						this.context.Negotiating.Cipher.InitializeCipher ();
+
+						negotiate.State = NegotiateState.SentCipherSpec;
+
+						// Send Change Cipher Spec message with the current cipher
+						// or as plain text if this is the initial negotiation
+						this.protocol.BeginSendChangeCipherSpec(NegotiateAsyncWorker, negotiate);
+					} else {
+						// Send client certificate if requested
+						// even if the server ask for it it _may_ still be optional
+						bool clientCertificate = this.context.ServerSettings.CertificateRequest;
+
+						using (var memstream = new MemoryStream())
+						{
+							// NOTE: sadly SSL3 and TLS1 differs in how they handle this and
+							// the current design doesn't allow a very cute way to handle 
+							// SSL3 alert warning for NoCertificate (41).
+							if (this.context.SecurityProtocol == SecurityProtocolType.Ssl3)
+							{
+								clientCertificate = ((this.context.ClientSettings.Certificates != null) &&
+									(this.context.ClientSettings.Certificates.Count > 0));
+								// this works well with OpenSSL (but only for SSL3)
+							}
+
+							byte[] record = null;
+
+							if (clientCertificate)
+							{
+								record = this.protocol.EncodeHandshakeRecord(HandshakeType.Certificate);
+								memstream.Write(record, 0, record.Length);
+							}
+
+							// Send Client Key Exchange
+							record = this.protocol.EncodeHandshakeRecord(HandshakeType.ClientKeyExchange);
+							memstream.Write(record, 0, record.Length);
+
+							// Now initialize session cipher with the generated keys
+							this.context.Negotiating.Cipher.InitializeCipher();
+
+							// Send certificate verify if requested (optional)
+							if (clientCertificate && (this.context.ClientSettings.ClientCertificate != null))
+							{
+								record = this.protocol.EncodeHandshakeRecord(HandshakeType.CertificateVerify);
+								memstream.Write(record, 0, record.Length);
+							}
+
+							// send the chnage cipher spec.
+							this.protocol.SendChangeCipherSpec(memstream);
+
+							// Send Finished message
+							record = this.protocol.EncodeHandshakeRecord(HandshakeType.Finished);
+							memstream.Write(record, 0, record.Length);
+
+							negotiate.State = NegotiateState.SentKeyExchange;
+
+							// send all the records.
+							this.innerStream.BeginWrite (memstream.GetBuffer (), 0, (int)memstream.Length, NegotiateAsyncWorker, negotiate);
+						}
+					}
+					break;
 
-				// Send Cipher Spec protocol
-				this.protocol.SendChangeCipherSpec ();
+				case NegotiateState.SentKeyExchange:
+					this.innerStream.EndWrite (result);
 
-				// Read record until server finished is received
-				while (this.context.HandshakeState != HandshakeState.Finished) 
-				{
-					// If all goes well this will process messages:
-					// 		Change Cipher Spec
-					//		Server finished
-					SafeReceiveRecord (this.innerStream);
-				}
+					negotiate.State = NegotiateState.ReceiveFinishResponse;
 
-				// Send Finished message
-				this.protocol.SendRecord (HandshakeType.Finished);
-			}
-			else
-			{
-				// Send client certificate if requested
-				// even if the server ask for it it _may_ still be optional
-				bool clientCertificate = this.context.ServerSettings.CertificateRequest;
-
-				// NOTE: sadly SSL3 and TLS1 differs in how they handle this and
-				// the current design doesn't allow a very cute way to handle 
-				// SSL3 alert warning for NoCertificate (41).
-				if (this.context.SecurityProtocol == SecurityProtocolType.Ssl3)
-				{
-					clientCertificate = ((this.context.ClientSettings.Certificates != null) &&
-						(this.context.ClientSettings.Certificates.Count > 0));
-					// this works well with OpenSSL (but only for SSL3)
-				}
+					this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
 
-				if (clientCertificate)
-				{
-					this.protocol.SendRecord(HandshakeType.Certificate);
-				}
+					break;
 
-				// Send Client Key Exchange
-				this.protocol.SendRecord(HandshakeType.ClientKeyExchange);
+				case NegotiateState.ReceiveFinishResponse:
+					this.SafeEndReceiveRecord (result);
+
+					// Read record until server finished is received
+					if (this.context.HandshakeState != HandshakeState.Finished) {
+						// If all goes well this will process messages:
+						// 		Change Cipher Spec
+						//		Server finished
+						this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+					}
+					else {
+						// Reset Handshake messages information
+						this.context.HandshakeMessages.Reset ();
+
+						// Clear Key Info
+						this.context.ClearKeyInfo();
+
+						negotiate.SetComplete ();
+					}
+					break;
 
-				// Now initialize session cipher with the generated keys
-				this.context.Negotiating.Cipher.InitializeCipher();
 
-				// Send certificate verify if requested (optional)
-				if (clientCertificate && (this.context.ClientSettings.ClientCertificate != null))
-				{
-					this.protocol.SendRecord(HandshakeType.CertificateVerify);
-				}
+				case NegotiateState.SentCipherSpec:
+					this.protocol.EndSendChangeCipherSpec (result);
 
-				// Send Cipher Spec protocol
-				this.protocol.SendChangeCipherSpec ();
+					negotiate.State = NegotiateState.ReceiveCipherSpecResponse;
 
-				// Send Finished message
-				this.protocol.SendRecord (HandshakeType.Finished);
+					// Start reading the cipher spec response
+					this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+					break;
 
-				// Read record until server finished is received
-				while (this.context.HandshakeState != HandshakeState.Finished) {
-					// If all goes well this will process messages:
-					// 		Change Cipher Spec
-					//		Server finished
-					SafeReceiveRecord (this.innerStream);
-				}
-			}
+				case NegotiateState.ReceiveCipherSpecResponse:
+					this.SafeEndReceiveRecord (result, true);
+
+					if (this.context.HandshakeState != HandshakeState.Finished)
+					{
+						this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+					}
+					else
+					{
+						negotiate.State = NegotiateState.SentFinished;
+						this.protocol.BeginSendRecord(HandshakeType.Finished, NegotiateAsyncWorker, negotiate);
+					}
+					break;
+
+				case NegotiateState.SentFinished:
+					this.protocol.EndSendRecord (result);
+
+					// Reset Handshake messages information
+					this.context.HandshakeMessages.Reset ();
 
-			// Reset Handshake messages information
-			this.context.HandshakeMessages.Reset ();
+					// Clear Key Info
+					this.context.ClearKeyInfo();
 
-			// Clear Key Info
-			this.context.ClearKeyInfo();
+					negotiate.SetComplete ();
 
+					break;
+				}
+			}
+			catch (TlsException ex)
+			{
+				// FIXME: should the send alert also be done asynchronously here and below?
+				this.protocol.SendAlert(ex.Alert);
+				negotiate.SetComplete (new IOException("The authentication or decryption has failed.", ex));
+			}
+			catch (Exception ex)
+			{
+				this.protocol.SendAlert(AlertDescription.InternalError);
+				negotiate.SetComplete (new IOException("The authentication or decryption has failed.", ex));
+			}
 		}
 
 		#endregion

+ 2 - 2
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/SslServerStream.cs

@@ -196,7 +196,7 @@ namespace Mono.Security.Protocol.Tls
 					Fig. 1 - Message flow for a full handshake		
 		*/
 
-		internal override IAsyncResult OnBeginNegotiateHandshake(AsyncCallback callback, object state)
+		internal override IAsyncResult BeginNegotiateHandshake(AsyncCallback callback, object state)
 		{
 			// Reset the context if needed
 			if (this.context.HandshakeState != HandshakeState.None)
@@ -215,7 +215,7 @@ namespace Mono.Security.Protocol.Tls
 
 		}
 
-		internal override void OnNegotiateHandshakeCallback(IAsyncResult asyncResult)
+		internal override void EndNegotiateHandshake(IAsyncResult asyncResult)
 		{
 			// Receive Client Hello message and ignore it
 			this.protocol.EndReceiveRecord(asyncResult);

+ 4 - 4
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/SslStreamBase.cs

@@ -96,7 +96,7 @@ namespace Mono.Security.Protocol.Tls
 			{
 				try
 				{
-					this.OnNegotiateHandshakeCallback(asyncResult);
+					this.EndNegotiateHandshake(asyncResult);
 				}
 				catch (TlsException ex)
 				{
@@ -179,8 +179,8 @@ namespace Mono.Security.Protocol.Tls
 
 		#region Abstracts/Virtuals
 
-		internal abstract IAsyncResult OnBeginNegotiateHandshake(AsyncCallback callback, object state);
-		internal abstract void OnNegotiateHandshakeCallback(IAsyncResult asyncResult);
+		internal abstract IAsyncResult BeginNegotiateHandshake (AsyncCallback callback, object state);
+		internal abstract void EndNegotiateHandshake (IAsyncResult result);
 
 		internal abstract X509Certificate OnLocalCertificateSelection(X509CertificateCollection clientCertificates,
 															X509Certificate serverCertificate,
@@ -492,7 +492,7 @@ namespace Mono.Security.Protocol.Tls
 				{
 					if (this.context.HandshakeState == HandshakeState.None)
 					{
-						this.OnBeginNegotiateHandshake(new AsyncCallback(AsyncHandshakeCallback), asyncResult);
+						this.BeginNegotiateHandshake(new AsyncCallback(AsyncHandshakeCallback), asyncResult);
 
 						return true;
 					}

+ 24 - 16
mcs/class/System/System.Net/HttpWebRequest.cs

@@ -1236,7 +1236,7 @@ namespace System.Net
 			}
 		}
 
-		internal void SendRequestHeaders (bool propagate_error)
+		internal byte[] GetRequestHeaders ()
 		{
 			StringBuilder req = new StringBuilder ();
 			string query;
@@ -1258,18 +1258,7 @@ namespace System.Net
 								actualVersion.Major, actualVersion.Minor);
 			req.Append (GetHeaders ());
 			string reqstr = req.ToString ();
-			byte [] bytes = Encoding.UTF8.GetBytes (reqstr);
-			try {
-				writeStream.SetHeaders (bytes);
-			} catch (WebException wexc) {
-				SetWriteStreamError (wexc.Status, wexc);
-				if (propagate_error)
-					throw;
-			} catch (Exception exc) {
-				SetWriteStreamError (WebExceptionStatus.SendFailure, exc);
-				if (propagate_error)
-					throw;
-			}
+			return Encoding.UTF8.GetBytes (reqstr);
 		}
 
 		internal void SetWriteStream (WebConnectionStream stream)
@@ -1284,14 +1273,32 @@ namespace System.Net
 				writeStream.SendChunked = false;
 			}
 
-			SendRequestHeaders (false);
+			byte[] requestHeaders = GetRequestHeaders ();
+			WebAsyncResult result = new WebAsyncResult (new AsyncCallback (SetWriteStreamCB), null);
+			writeStream.SetHeadersAsync (requestHeaders, result);
+		}
 
+		void SetWriteStreamCB(IAsyncResult ar)
+		{
+			WebAsyncResult result = ar as WebAsyncResult;
+
+			if (result.Exception != null) {
+				WebException wexc = result.Exception as WebException;
+				if (wexc != null) {
+					SetWriteStreamError (wexc.Status, wexc);
+					return;
+				}
+				SetWriteStreamError (WebExceptionStatus.SendFailure, result.Exception);
+				return;
+			}
+		
 			haveRequest = true;
-			
+
 			if (bodyBuffer != null) {
 				// The body has been written and buffered. The request "user"
 				// won't write it again, so we must do it.
 				if (ntlm_auth_state != NtlmAuthState.Challenge) {
+					// FIXME: this is a blocking call on the thread pool that could lead to thread pool exhaustion
 					writeStream.Write (bodyBuffer, 0, bodyBufferLength);
 					bodyBuffer = null;
 					writeStream.Close ();
@@ -1299,11 +1306,12 @@ namespace System.Net
 			} else if (method != "HEAD" && method != "GET" && method != "MKCOL" && method != "CONNECT" &&
 					method != "TRACE") {
 				if (getResponseCalled && !writeStream.RequestWritten)
+					// FIXME: this is a blocking call on the thread pool that could lead to thread pool exhaustion
 					writeStream.WriteRequest ();
 			}
 
 			if (asyncWrite != null) {
-				asyncWrite.SetCompleted (false, stream);
+				asyncWrite.SetCompleted (false, writeStream);
 				asyncWrite.DoCallback ();
 				asyncWrite = null;
 			}

+ 40 - 15
mcs/class/System/System.Net/WebConnectionStream.cs

@@ -632,7 +632,7 @@ namespace System.Net
 		{
 		}
 
-		internal void SetHeaders (byte [] buffer)
+		internal void SetHeadersAsync (byte[] buffer, WebAsyncResult result)
 		{
 			if (headersSent)
 				return;
@@ -646,14 +646,44 @@ namespace System.Net
 			               method == "COPY" || method == "MOVE" || method == "LOCK" ||
 			               method == "UNLOCK");
 			if (sendChunked || cl > -1 || no_writestream || webdav) {
-				WriteHeaders ();
+
+				headersSent = true;
+
+				try {
+					result.InnerAsyncResult = cnc.BeginWrite (request, headers, 0, headers.Length, new AsyncCallback(SetHeadersCB), result);
+					if (result.InnerAsyncResult == null) {
+						// when does BeginWrite return null? Is the case when the request is aborted?
+						if (!result.IsCompleted)
+							result.SetCompleted (true, 0);
+						result.DoCallback ();
+					}
+				} catch (Exception exc) {
+					result.SetCompleted (true, exc);
+					result.DoCallback ();
+				}
+			}
+		}
+
+		void SetHeadersCB (IAsyncResult r)
+		{
+			WebAsyncResult result = (WebAsyncResult) r.AsyncState;
+			result.InnerAsyncResult = null;
+			try {
+				cnc.EndWrite2 (request, r);
+				result.SetCompleted (false, 0);
 				if (!initRead) {
 					initRead = true;
 					WebConnection.InitRead (cnc);
 				}
+				long cl = request.ContentLength;
 				if (!sendChunked && cl == 0)
 					requestWritten = true;
+			} catch (WebException e) {
+				result.SetCompleted (false, e);
+			} catch (Exception e) {
+				result.SetCompleted (false, new WebException ("Error writing headers", e, WebExceptionStatus.SendFailure));
 			}
+			result.DoCallback ();
 		}
 
 		internal bool RequestWritten {
@@ -669,17 +699,6 @@ namespace System.Net
 			return (length > 0) ? cnc.BeginWrite (request, bytes, 0, length, cb, state) : null;
 		}
 
-		void WriteHeaders ()
-		{
-			if (headersSent)
-				return;
-
-			headersSent = true;
-			string err_msg = null;
-			if (!cnc.Write (request, headers, 0, headers.Length, ref err_msg))
-				throw new WebException ("Error writing request: " + err_msg, null, WebExceptionStatus.SendFailure, null);
-		}
-
 		internal void WriteRequest ()
 		{
 			if (requestWritten)
@@ -707,9 +726,15 @@ namespace System.Net
 							method == "TRACE");
 				if (!no_writestream)
 					request.InternalContentLength = length;
-				request.SendRequestHeaders (true);
+
+				byte[] requestHeaders = request.GetRequestHeaders ();
+				WebAsyncResult ar = new WebAsyncResult (null, null);
+				SetHeadersAsync (requestHeaders, ar);
+				ar.AsyncWaitHandle.WaitOne ();
+				if (ar.Exception != null)
+					throw ar.Exception;
 			}
-			WriteHeaders ();
+
 			if (cnc.Data.StatusCode != 0 && cnc.Data.StatusCode != 100)
 				return;