| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267 |
- // <copyright>
- // Copyright (c) Microsoft Corporation. All rights reserved.
- // </copyright>
namespace System.ServiceModel.Channels
{
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Linq;
    using System.Net;
    using System.Net.Http;
    using System.Net.WebSockets;
    using System.Runtime;
    using System.Threading;

    /// <summary>
    /// Connection handler that validates an incoming WebSocket upgrade request
    /// (version header, and optionally the SOAP content-type and binary-encoder
    /// transfer-mode headers) and then negotiates the subprotocol against the
    /// single subprotocol this handler was configured with.
    /// </summary>
    class DefaultWebSocketConnectionHandler : WebSocketConnectionHandler
    {
        readonly string currentVersion;
        readonly string subProtocol;

        // Only set when messageVersion != MessageVersion.None (see constructor).
        readonly MessageEncoder encoder;

        // String form of the configured TransferMode; only set for the binary encoder.
        readonly string transferMode;

        readonly bool needToCheckContentType;
        readonly bool needToCheckTransferMode;

        // Validators are cached as delegates so AcceptWebSocket does not allocate one per request.
        readonly Func<string, bool> checkVersionFunc;
        readonly Func<string, bool> checkContentTypeFunc;
        readonly Func<string, bool> checkTransferModeFunc;

        /// <summary>
        /// Creates the handler.
        /// </summary>
        /// <param name="subProtocol">Subprotocol the server side expects; may be null or empty.</param>
        /// <param name="currentVersion">The only WebSocket version value that is accepted.</param>
        /// <param name="messageVersion">When not <c>MessageVersion.None</c>, the content-type header is validated.</param>
        /// <param name="encoderFactory">Used to create the session encoder that validates the content type.</param>
        /// <param name="transferMode">Expected transfer mode; only validated for the binary encoder factory.</param>
        public DefaultWebSocketConnectionHandler(string subProtocol, string currentVersion, MessageVersion messageVersion, MessageEncoderFactory encoderFactory, TransferMode transferMode)
        {
            this.subProtocol = subProtocol;
            this.currentVersion = currentVersion;
            this.checkVersionFunc = new Func<string, bool>(this.CheckVersion);

            if (messageVersion != MessageVersion.None)
            {
                this.needToCheckContentType = true;
                this.encoder = encoderFactory.CreateSessionEncoder();
                this.checkContentTypeFunc = new Func<string, bool>(this.CheckContentType);

                if (encoderFactory is BinaryMessageEncoderFactory)
                {
                    this.needToCheckTransferMode = true;
                    this.transferMode = transferMode.ToString();
                    this.checkTransferModeFunc = new Func<string, bool>(this.CheckTransferMode);
                }
            }
        }

        /// <summary>
        /// Validates the upgrade request and returns either the accepted response
        /// or a rejection response describing what the server requires.
        /// </summary>
        /// <param name="request">The incoming upgrade request.</param>
        /// <param name="cancellationToken">Not used by this implementation.</param>
        /// <returns>
        /// The accepted response (with the negotiated subprotocol header, if any), an
        /// upgrade-required response when the version or subprotocol does not match, or
        /// a bad-request response when the content type, transfer mode, or subprotocol
        /// header is invalid.
        /// </returns>
        protected internal override HttpResponseMessage AcceptWebSocket(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            if (!CheckHttpHeader(request, WebSocketHelper.SecWebSocketVersion, this.checkVersionFunc))
            {
                return GetUpgradeRequiredResponseMessageWithVersion(request, this.currentVersion);
            }

            if (this.needToCheckContentType)
            {
                if (!CheckHttpHeader(request, WebSocketTransportSettings.SoapContentTypeHeader, this.checkContentTypeFunc))
                {
                    return this.GetBadRequestResponseMessageWithContentTypeAndTransfermode(request);
                }

                if (this.needToCheckTransferMode && !CheckHttpHeader(request, WebSocketTransportSettings.BinaryEncoderTransferModeHeader, this.checkTransferModeFunc))
                {
                    return this.GetBadRequestResponseMessageWithContentTypeAndTransfermode(request);
                }
            }

            HttpResponseMessage response = GetWebSocketAcceptedResponseMessage(request);
            HttpResponseMessage rejection = this.NegotiateSubProtocol(request, response);
            if (rejection != null)
            {
                // The accepted response is not going to be returned, so release it
                // instead of abandoning it to the finalizer/GC.
                // NOTE(review): assumes GetWebSocketAcceptedResponseMessage returns a fresh
                // instance per call; the per-request header mutation in NegotiateSubProtocol
                // already relies on that.
                if (response != null)
                {
                    response.Dispose();
                }

                return rejection;
            }

            return response;
        }

        /// <summary>
        /// Matches the client's requested subprotocols against the configured one and,
        /// on success, stamps the negotiated value onto <paramref name="response"/>.
        /// </summary>
        /// <returns>
        /// <c>null</c> when negotiation succeeded (or no subprotocol is involved on either
        /// side); otherwise the rejection response that should be sent instead of
        /// <paramref name="response"/>.
        /// </returns>
        HttpResponseMessage NegotiateSubProtocol(HttpRequestMessage request, HttpResponseMessage response)
        {
            SubprotocolParseResult subprotocolParseResult = ParseSubprotocolValues(request);

            if (!subprotocolParseResult.HeaderFound)
            {
                // The client asked for no subprotocol; that is only acceptable when the
                // server does not require one either.
                if (!string.IsNullOrEmpty(this.subProtocol))
                {
                    FxTrace.Exception.AsWarning(new WebException(
                        SR.GetString(SR.WebSocketInvalidProtocolNoHeader, this.subProtocol, WebSocketHelper.SecWebSocketProtocol)));
                    return GetUpgradeRequiredResponseMessageWithSubProtocol(request, this.subProtocol);
                }

                return null;
            }

            if (!subprotocolParseResult.HeaderValid)
            {
                return GetBadRequestResponseMessage(request);
            }

            // Match client protocols against the server protocol; the first match wins
            // and the client's spelling of the value is what gets echoed back.
            string negotiatedProtocol = null;
            foreach (string protocol in subprotocolParseResult.ParsedSubprotocols)
            {
                if (string.Equals(protocol, this.subProtocol, StringComparison.OrdinalIgnoreCase))
                {
                    negotiatedProtocol = protocol;
                    break;
                }
            }

            if (negotiatedProtocol == null)
            {
                FxTrace.Exception.AsWarning(new WebException(
                    SR.GetString(SR.WebSocketInvalidProtocolNotInClientList, this.subProtocol, string.Join(", ", subprotocolParseResult.ParsedSubprotocols))));
                return GetUpgradeRequiredResponseMessageWithSubProtocol(request, this.subProtocol);
            }

            // Set the response header: drop whatever is there and add the negotiated
            // value, unless the negotiated value is the empty placeholder.
            response.Headers.Remove(WebSocketHelper.SecWebSocketProtocol);
            if (negotiatedProtocol.Length > 0)
            {
                response.Headers.Add(WebSocketHelper.SecWebSocketProtocol, negotiatedProtocol);
            }

            return null;
        }

        /// <summary>
        /// Collects every subprotocol token from the request's subprotocol header(s).
        /// </summary>
        /// <returns>
        /// <c>HeaderNotFound</c> when the header is absent, <c>HeaderInvalid</c> when any
        /// header value fails to parse, otherwise a valid result holding at least one token.
        /// </returns>
        static SubprotocolParseResult ParseSubprotocolValues(HttpRequestMessage request)
        {
            Fx.Assert(request != null, "request should not be null");

            IEnumerable<string> clientProtocols = null;
            if (!request.Headers.TryGetValues(WebSocketHelper.SecWebSocketProtocol, out clientProtocols))
            {
                return SubprotocolParseResult.HeaderNotFound;
            }

            // The request may carry multiple subprotocol headers. Build one list with all the values.
            // Duplicates are harmless because the caller always takes the first matching value.
            List<string> tokenList = new List<string>();
            foreach (string headerValue in clientProtocols)
            {
                List<string> protocolList;
                if (!WebSocketHelper.TryParseSubProtocol(headerValue, out protocolList))
                {
                    return SubprotocolParseResult.HeaderInvalid;
                }

                tokenList.AddRange(protocolList);
            }

            // The header was present but produced no tokens: report a single empty token
            // so a found-and-valid result always contains at least one entry.
            if (tokenList.Count == 0)
            {
                tokenList.Add(string.Empty);
            }

            return new SubprotocolParseResult(true, true, tokenList);
        }

        /// <summary>
        /// Builds an upgrade-required response that advertises the server's subprotocol
        /// (the header is omitted when <paramref name="subprotocol"/> is null or empty).
        /// </summary>
        static HttpResponseMessage GetUpgradeRequiredResponseMessageWithSubProtocol(HttpRequestMessage request, string subprotocol)
        {
            HttpResponseMessage response = GetUpgradeRequiredResponseMessage(request);
            if (!string.IsNullOrEmpty(subprotocol))
            {
                response.Headers.Add(WebSocketHelper.SecWebSocketProtocol, subprotocol);
            }

            return response;
        }

        /// <summary>
        /// Builds an upgrade-required response that advertises the accepted version.
        /// </summary>
        static HttpResponseMessage GetUpgradeRequiredResponseMessageWithVersion(HttpRequestMessage request, string version)
        {
            HttpResponseMessage response = GetUpgradeRequiredResponseMessage(request);
            response.Headers.Add(WebSocketHelper.SecWebSocketVersion, version);
            return response;
        }

        /// <summary>
        /// Checks that <paramref name="header"/> is present on the request and that every
        /// non-null value (after trimming) satisfies <paramref name="validator"/>.
        /// </summary>
        /// <returns><c>false</c> when the header is missing or any value is rejected; otherwise <c>true</c>.</returns>
        static bool CheckHttpHeader(HttpRequestMessage request, string header, Func<string, bool> validator)
        {
            Fx.Assert(request != null, "request should not be null.");
            Fx.Assert(header != null, "header should not be null.");
            Fx.Assert(validator != null, "validator should not be null.");

            IEnumerable<string> headerValues;
            if (!request.Headers.TryGetValues(header, out headerValues))
            {
                return false;
            }

            foreach (string headerValue in headerValues)
            {
                // Null entries are skipped rather than treated as failures.
                if (headerValue != null && !validator(headerValue.Trim()))
                {
                    return false;
                }
            }

            return true;
        }

        // Exact (case-sensitive) match against the configured version.
        bool CheckVersion(string headerValue)
        {
            Fx.Assert(headerValue != null, "headerValue should not be null.");
            return headerValue == this.currentVersion;
        }

        // Delegates to the session encoder to decide whether the content type is supported.
        bool CheckContentType(string headerValue)
        {
            Fx.Assert(headerValue != null, "headerValue should not be null.");
            return this.encoder.IsContentTypeSupported(headerValue);
        }

        // Case-insensitive match against the configured transfer mode name.
        bool CheckTransferMode(string headerValue)
        {
            Fx.Assert(headerValue != null, "headerValue should not be null.");
            return headerValue.Equals(this.transferMode, StringComparison.OrdinalIgnoreCase);
        }

        /// <summary>
        /// Builds a bad-request response that tells the client which content type (and,
        /// for the binary encoder, which transfer mode) the server expects.
        /// </summary>
        HttpResponseMessage GetBadRequestResponseMessageWithContentTypeAndTransfermode(HttpRequestMessage request)
        {
            Fx.Assert(this.needToCheckContentType, "needToCheckContentType should be true.");

            HttpResponseMessage response = GetBadRequestResponseMessage(request);
            response.Headers.Add(WebSocketTransportSettings.SoapContentTypeHeader, this.encoder.ContentType);
            if (this.needToCheckTransferMode)
            {
                // transferMode is already stored as a string; no conversion needed.
                response.Headers.Add(WebSocketTransportSettings.BinaryEncoderTransferModeHeader, this.transferMode);
            }

            return response;
        }

        /// <summary>
        /// Outcome of parsing the request's subprotocol header(s): whether the header was
        /// present, whether it parsed, and the tokens it contained.
        /// </summary>
        struct SubprotocolParseResult
        {
            // Header present but at least one value failed to parse.
            public static readonly SubprotocolParseResult HeaderInvalid = new SubprotocolParseResult(true, false, null);

            // Header absent from the request.
            public static readonly SubprotocolParseResult HeaderNotFound = new SubprotocolParseResult(false, false, null);

            readonly bool headerFound;
            readonly bool headerValid;
            readonly IEnumerable<string> parsedSubprotocols;

            public SubprotocolParseResult(bool headerFound, bool headerValid, IEnumerable<string> parsedSubprotocols)
            {
                this.headerFound = headerFound;
                this.headerValid = headerValid;
                this.parsedSubprotocols = parsedSubprotocols;
            }

            public bool HeaderFound
            {
                get { return this.headerFound; }
            }

            public bool HeaderValid
            {
                get { return this.headerValid; }
            }

            // Null for the HeaderInvalid / HeaderNotFound sentinels.
            public IEnumerable<string> ParsedSubprotocols
            {
                get { return this.parsedSubprotocols; }
            }
        }
    }
}
|