DefaultWebSocketConnectionHandler.cs 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. // <copyright>
  2. // Copyright (c) Microsoft Corporation. All rights reserved.
  3. // </copyright>
  4. namespace System.ServiceModel.Channels
  5. {
  6. using System;
  7. using System.Collections;
  8. using System.Collections.Generic;
  9. using System.Linq;
  10. using System.Net;
  11. using System.Net.Http;
  12. using System.Net.WebSockets;
  13. using System.Runtime;
  14. using System.Threading;
  15. class DefaultWebSocketConnectionHandler : WebSocketConnectionHandler
  16. {
  17. string currentVersion;
  18. string subProtocol;
  19. MessageEncoder encoder;
  20. string transferMode;
  21. bool needToCheckContentType;
  22. bool needToCheckTransferMode;
  23. Func<string, bool> checkVersionFunc;
  24. Func<string, bool> checkContentTypeFunc;
  25. Func<string, bool> checkTransferModeFunc;
  26. public DefaultWebSocketConnectionHandler(string subProtocol, string currentVersion, MessageVersion messageVersion, MessageEncoderFactory encoderFactory, TransferMode transferMode)
  27. {
  28. this.subProtocol = subProtocol;
  29. this.currentVersion = currentVersion;
  30. this.checkVersionFunc = new Func<string, bool>(this.CheckVersion);
  31. if (messageVersion != MessageVersion.None)
  32. {
  33. this.needToCheckContentType = true;
  34. this.encoder = encoderFactory.CreateSessionEncoder();
  35. this.checkContentTypeFunc = new Func<string, bool>(this.CheckContentType);
  36. if (encoderFactory is BinaryMessageEncoderFactory)
  37. {
  38. this.needToCheckTransferMode = true;
  39. this.transferMode = transferMode.ToString();
  40. this.checkTransferModeFunc = new Func<string, bool>(this.CheckTransferMode);
  41. }
  42. }
  43. }
  44. protected internal override HttpResponseMessage AcceptWebSocket(HttpRequestMessage request, CancellationToken cancellationToken)
  45. {
  46. if (!CheckHttpHeader(request, WebSocketHelper.SecWebSocketVersion, this.checkVersionFunc))
  47. {
  48. return GetUpgradeRequiredResponseMessageWithVersion(request, this.currentVersion);
  49. }
  50. if (this.needToCheckContentType)
  51. {
  52. if (!CheckHttpHeader(request, WebSocketTransportSettings.SoapContentTypeHeader, this.checkContentTypeFunc))
  53. {
  54. return this.GetBadRequestResponseMessageWithContentTypeAndTransfermode(request);
  55. }
  56. if (this.needToCheckTransferMode && !CheckHttpHeader(request, WebSocketTransportSettings.BinaryEncoderTransferModeHeader, this.checkTransferModeFunc))
  57. {
  58. return this.GetBadRequestResponseMessageWithContentTypeAndTransfermode(request);
  59. }
  60. }
  61. HttpResponseMessage response = GetWebSocketAcceptedResponseMessage(request);
  62. SubprotocolParseResult subprotocolParseResult = ParseSubprotocolValues(request);
  63. if (subprotocolParseResult.HeaderFound)
  64. {
  65. if (!subprotocolParseResult.HeaderValid)
  66. {
  67. return GetBadRequestResponseMessage(request);
  68. }
  69. string negotiatedProtocol = null;
  70. // match client protocols vs server protocol
  71. foreach (string protocol in subprotocolParseResult.ParsedSubprotocols)
  72. {
  73. if (string.Compare(protocol, this.subProtocol, StringComparison.OrdinalIgnoreCase) == 0)
  74. {
  75. negotiatedProtocol = protocol;
  76. break;
  77. }
  78. }
  79. if (negotiatedProtocol == null)
  80. {
  81. FxTrace.Exception.AsWarning(new WebException(
  82. SR.GetString(SR.WebSocketInvalidProtocolNotInClientList, this.subProtocol, string.Join(", ", subprotocolParseResult.ParsedSubprotocols))));
  83. return GetUpgradeRequiredResponseMessageWithSubProtocol(request, this.subProtocol);
  84. }
  85. // set response header
  86. response.Headers.Remove(WebSocketHelper.SecWebSocketProtocol);
  87. if (negotiatedProtocol != string.Empty)
  88. {
  89. response.Headers.Add(WebSocketHelper.SecWebSocketProtocol, negotiatedProtocol);
  90. }
  91. }
  92. else
  93. {
  94. if (!string.IsNullOrEmpty(this.subProtocol))
  95. {
  96. FxTrace.Exception.AsWarning(new WebException(
  97. SR.GetString(SR.WebSocketInvalidProtocolNoHeader, this.subProtocol, WebSocketHelper.SecWebSocketProtocol)));
  98. return GetUpgradeRequiredResponseMessageWithSubProtocol(request, this.subProtocol);
  99. }
  100. }
  101. return response;
  102. }
  103. static SubprotocolParseResult ParseSubprotocolValues(HttpRequestMessage request)
  104. {
  105. Fx.Assert(request != null, "request should not be null");
  106. IEnumerable<string> clientProtocols = null;
  107. if (request.Headers.TryGetValues(WebSocketHelper.SecWebSocketProtocol, out clientProtocols))
  108. {
  109. List<string> tokenList = new List<string>();
  110. // We may have multiple subprotocol header in the response. We will build up a list with all the subprotocol values.
  111. // There might be duplicated ones inside the list, but it doesn't matter since we will always take the first matching value.
  112. foreach (string headerValue in clientProtocols)
  113. {
  114. List<string> protocolList;
  115. if (WebSocketHelper.TryParseSubProtocol(headerValue, out protocolList))
  116. {
  117. tokenList.AddRange(protocolList);
  118. }
  119. else
  120. {
  121. return SubprotocolParseResult.HeaderInvalid;
  122. }
  123. }
  124. // If this method returns true, we should ensure that clientProtocols always contains at least one entry
  125. if (tokenList.Count == 0)
  126. {
  127. tokenList.Add(string.Empty);
  128. }
  129. return new SubprotocolParseResult(true, true, tokenList);
  130. }
  131. return SubprotocolParseResult.HeaderNotFound;
  132. }
  133. static HttpResponseMessage GetUpgradeRequiredResponseMessageWithSubProtocol(HttpRequestMessage request, string subprotocol)
  134. {
  135. HttpResponseMessage response = GetUpgradeRequiredResponseMessage(request);
  136. if (!string.IsNullOrEmpty(subprotocol))
  137. {
  138. response.Headers.Add(WebSocketHelper.SecWebSocketProtocol, subprotocol);
  139. }
  140. return response;
  141. }
  142. static HttpResponseMessage GetUpgradeRequiredResponseMessageWithVersion(HttpRequestMessage request, string version)
  143. {
  144. HttpResponseMessage response = GetUpgradeRequiredResponseMessage(request);
  145. response.Headers.Add(WebSocketHelper.SecWebSocketVersion, version);
  146. return response;
  147. }
  148. static bool CheckHttpHeader(HttpRequestMessage request, string header, Func<string, bool> validator)
  149. {
  150. Fx.Assert(request != null, "request should not be null.");
  151. Fx.Assert(header != null, "header should not be null.");
  152. Fx.Assert(validator != null, "validator should not be null.");
  153. IEnumerable<string> headerValues;
  154. if (!request.Headers.TryGetValues(header, out headerValues))
  155. {
  156. return false;
  157. }
  158. bool isValid = false;
  159. foreach (string headerValue in headerValues)
  160. {
  161. if (headerValue != null)
  162. {
  163. isValid = validator(headerValue.Trim());
  164. if (!isValid)
  165. {
  166. return false;
  167. }
  168. }
  169. }
  170. return true;
  171. }
  172. bool CheckVersion(string headerValue)
  173. {
  174. Fx.Assert(headerValue != null, "headerValue should not be null.");
  175. return headerValue == this.currentVersion;
  176. }
  177. bool CheckContentType(string headerValue)
  178. {
  179. Fx.Assert(headerValue != null, "headerValue should not be null.");
  180. return this.encoder.IsContentTypeSupported(headerValue);
  181. }
  182. bool CheckTransferMode(string headerValue)
  183. {
  184. Fx.Assert(headerValue != null, "headerValue should not be null.");
  185. return headerValue.Equals(this.transferMode, StringComparison.OrdinalIgnoreCase);
  186. }
  187. HttpResponseMessage GetBadRequestResponseMessageWithContentTypeAndTransfermode(HttpRequestMessage request)
  188. {
  189. Fx.Assert(this.needToCheckContentType, "needToCheckContentType should be true.");
  190. HttpResponseMessage response = GetBadRequestResponseMessage(request);
  191. response.Headers.Add(WebSocketTransportSettings.SoapContentTypeHeader, this.encoder.ContentType);
  192. if (this.needToCheckTransferMode)
  193. {
  194. response.Headers.Add(WebSocketTransportSettings.BinaryEncoderTransferModeHeader, this.transferMode.ToString());
  195. }
  196. return response;
  197. }
  198. struct SubprotocolParseResult
  199. {
  200. public static readonly SubprotocolParseResult HeaderInvalid = new SubprotocolParseResult(true, false, null);
  201. public static readonly SubprotocolParseResult HeaderNotFound = new SubprotocolParseResult(false, false, null);
  202. bool headerFound;
  203. bool headerValid;
  204. IEnumerable<string> parsedSubprotocols;
  205. public SubprotocolParseResult(bool headerFound, bool headerValid, IEnumerable<string> parsedSubprotocols)
  206. {
  207. this.headerFound = headerFound;
  208. this.headerValid = headerValid;
  209. this.parsedSubprotocols = parsedSubprotocols;
  210. }
  211. public bool HeaderFound
  212. {
  213. get { return this.headerFound; }
  214. }
  215. public bool HeaderValid
  216. {
  217. get { return this.headerValid; }
  218. }
  219. public IEnumerable<string> ParsedSubprotocols
  220. {
  221. get { return this.parsedSubprotocols; }
  222. }
  223. }
  224. }
  225. }