WebSocketHelper.cs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. // <copyright>
  2. // Copyright (c) Microsoft Corporation. All rights reserved.
  3. // </copyright>
  4. namespace System.ServiceModel.Channels
  5. {
  6. using System;
  7. using System.Collections.Generic;
  8. using System.Globalization;
  9. using System.Linq;
  10. using System.Net;
  11. using System.Net.WebSockets;
  12. using System.Runtime;
  13. using System.Runtime.InteropServices;
  14. using System.Security.Cryptography;
  15. using System.Text;
  16. using System.Threading;
  17. using System.Threading.Tasks;
  // Shared helpers for the WebSocket transport: handshake header computation, ws/wss <-> http/https
  // URI mapping, Sec-WebSocket-Protocol parsing, internal buffer sizing, and conversion of
  // WebSocket/Task failures into WCF communication exceptions.
  static class WebSocketHelper
  {
      internal const int OperationNotStarted = 0;
      internal const int OperationFinished = 1;

      internal const string SecWebSocketKey = "Sec-WebSocket-Key";
      internal const string SecWebSocketVersion = "Sec-WebSocket-Version";
      internal const string SecWebSocketProtocol = "Sec-WebSocket-Protocol";
      internal const string SecWebSocketAccept = "Sec-WebSocket-Accept";
      internal const string MaxPendingConnectionsString = "MaxPendingConnections";
      internal const string WebSocketTransportSettingsString = "WebSocketTransportSettings";

      // Operation names that GetTimeoutException maps to operation-specific timeout messages.
      internal const string CloseOperation = "CloseOperation";
      internal const string SendOperation = "SendOperation";
      internal const string ReceiveOperation = "ReceiveOperation";

      // Separator used to split a comma-delimited Sec-WebSocket-Protocol header value.
      internal static readonly char[] ProtocolSeparators = new char[] { ',' };

      // GUID appended to the client's key when computing Sec-WebSocket-Accept (RFC 6455, section 4.2.2).
      const string WebSocketKeyPostString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
      const string SchemeWs = "ws";
      const string SchemeWss = "wss";

      // Two uints + a marshaled bool + a pointer. Per the layout note in ComputeInternalBufferSize
      // this comes to 16 bytes in a 32-bit process and 20 bytes in a 64-bit process.
      static readonly int PropertyBufferSize = ((2 * Marshal.SizeOf(typeof(uint))) + Marshal.SizeOf(typeof(bool))) + IntPtr.Size;

      // Separator characters that may not appear in a sub-protocol token. Control characters and
      // non-ASCII characters are rejected separately by the range check in IsSubProtocolInvalid.
      static readonly HashSet<char> InvalidSeparatorSet = new HashSet<char>(new char[] { '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ' });

      // Lazily populated cache for GetCurrentVersion. Not synchronized (see GetCurrentVersion).
      static string currentWebSocketVersion;

      /// <summary>
      /// Computes the Sec-WebSocket-Accept value for a client's Sec-WebSocket-Key.
      /// The key is concatenated with the RFC 6455 GUID, hashed with SHA-1, and Base64-encoded.
      /// </summary>
      /// <param name="webSocketKey">The client's Sec-WebSocket-Key value; must not be null (debug-asserted).</param>
      /// <returns>The Base64-encoded SHA-1 digest.</returns>
      internal static string ComputeAcceptHeader(string webSocketKey)
      {
          Fx.Assert(webSocketKey != null, "webSocketKey should not be null.");

          // SHA-1 is mandated by the WebSocket handshake here; it is not used as a security primitive.
          using (SHA1 sha = SHA1.Create())
          {
              string fullString = webSocketKey + WebSocketHelper.WebSocketKeyPostString;
              byte[] bytes = Encoding.UTF8.GetBytes(fullString);
              return Convert.ToBase64String(sha.ComputeHash(bytes));
          }
      }

      /// <summary>Internal buffer size for a client-side WebSocket (see ComputeInternalBufferSize).</summary>
      internal static int ComputeClientBufferSize(long maxReceivedMessageSize)
      {
          return ComputeInternalBufferSize(maxReceivedMessageSize, false);
      }

      /// <summary>Internal buffer size for a server-side WebSocket (see ComputeInternalBufferSize).</summary>
      internal static int ComputeServerBufferSize(long maxReceivedMessageSize)
      {
          return ComputeInternalBufferSize(maxReceivedMessageSize, true);
      }

      /// <summary>
      /// Returns the receive buffer size: <paramref name="maxReceivedMessageSize"/> capped at
      /// WebSocketDefaults.BufferSize, but never smaller than WebSocketDefaults.MinReceiveBufferSize.
      /// </summary>
      internal static int GetReceiveBufferSize(long maxReceivedMessageSize)
      {
          // The cast to int is only taken when the value is <= BufferSize, so it cannot overflow.
          int effectiveMaxReceiveBufferSize = maxReceivedMessageSize <= WebSocketDefaults.BufferSize ? (int)maxReceivedMessageSize : WebSocketDefaults.BufferSize;
          return Math.Max(WebSocketDefaults.MinReceiveBufferSize, effectiveMaxReceiveBufferSize);
      }

      /// <summary>
      /// Returns true when <paramref name="transportUsage"/> is Always, or when it is WhenDuplex and the
      /// contract is duplex.
      /// </summary>
      internal static bool UseWebSocketTransport(WebSocketTransportUsage transportUsage, bool isContractDuplex)
      {
          return transportUsage == WebSocketTransportUsage.Always
              || (transportUsage == WebSocketTransportUsage.WhenDuplex && isContractDuplex);
      }

      /// <summary>
      /// Maps an http/https URI to the equivalent ws/wss URI, leaving every other component as is.
      /// </summary>
      /// <param name="httpUri">An http or https URI; must not be null (debug-asserted).</param>
      /// <returns>
      /// A ws URI for an http input. For any other scheme the result is wss; non-https input only
      /// triggers a debug assertion.
      /// </returns>
      internal static Uri GetWebSocketUri(Uri httpUri)
      {
          Fx.Assert(httpUri != null, "RemoteAddress.Uri should not be null.");
          UriBuilder builder = new UriBuilder(httpUri);
          if (Uri.UriSchemeHttp.Equals(httpUri.Scheme, StringComparison.OrdinalIgnoreCase))
          {
              builder.Scheme = SchemeWs;
          }
          else
          {
              Fx.Assert(
                  Uri.UriSchemeHttps.Equals(httpUri.Scheme, StringComparison.OrdinalIgnoreCase),
                  "httpUri.Scheme should be http or https.");
              builder.Scheme = SchemeWss;
          }

          return builder.Uri;
      }

      /// <summary>
      /// Returns true when <paramref name="uri"/> is non-null and its scheme is ws or wss
      /// (case-insensitive).
      /// </summary>
      internal static bool IsWebSocketUri(Uri uri)
      {
          return uri != null &&
              (WebSocketHelper.SchemeWs.Equals(uri.Scheme, StringComparison.OrdinalIgnoreCase) ||
               WebSocketHelper.SchemeWss.Equals(uri.Scheme, StringComparison.OrdinalIgnoreCase));
      }

      /// <summary>
      /// Maps a ws/wss URI to the equivalent http/https URI. Any other URI is returned unchanged
      /// (the same instance).
      /// </summary>
      /// <param name="uri">The URI to normalize; must not be null (debug-asserted).</param>
      internal static Uri NormalizeWsSchemeWithHttpScheme(Uri uri)
      {
          Fx.Assert(uri != null, "RemoteAddress.Uri should not be null.");
          if (!IsWebSocketUri(uri))
          {
              return uri;
          }

          UriBuilder builder = new UriBuilder(uri);

          // IsWebSocketUri has established the scheme is ws or wss (ordinal, case-insensitive), so a
          // two-way choice using the same comparison suffices. No lower-cased copy of the scheme and
          // no unreachable default branch are needed.
          builder.Scheme = SchemeWs.Equals(uri.Scheme, StringComparison.OrdinalIgnoreCase)
              ? Uri.UriSchemeHttp
              : Uri.UriSchemeHttps;
          return builder.Uri;
      }

      /// <summary>
      /// Parses a comma-separated Sec-WebSocket-Protocol header value into trimmed tokens.
      /// </summary>
      /// <param name="subProtocolValue">The raw header value; null yields an empty list and true.</param>
      /// <param name="subProtocolList">
      /// Always assigned a new list. On failure it holds the tokens accepted before the invalid one.
      /// </param>
      /// <returns>
      /// False as soon as a token contains an invalid character; a WebException is passed to
      /// FxTrace.Exception.AsWarning in that case. Otherwise true.
      /// </returns>
      internal static bool TryParseSubProtocol(string subProtocolValue, out List<string> subProtocolList)
      {
          subProtocolList = new List<string>();
          if (subProtocolValue != null)
          {
              string[] parsedTokens = subProtocolValue.Split(ProtocolSeparators, StringSplitOptions.RemoveEmptyEntries);
              for (int i = 0; i < parsedTokens.Length; i++)
              {
                  string token = parsedTokens[i];

                  // RemoveEmptyEntries drops "" but not whitespace-only tokens such as " ", so skip those here.
                  if (!string.IsNullOrWhiteSpace(token))
                  {
                      token = token.Trim();
                      string invalidChar;
                      if (!IsSubProtocolInvalid(token, out invalidChar))
                      {
                          // Note that we could be adding a duplicate to this list. According to the specification the header should not include
                          // duplicates but we aim to be "robust in what we receive" so we will allow it. The matching code that consumes this list
                          // will take the first match so duplicates will not affect the outcome of the negotiation process.
                          subProtocolList.Add(token);
                      }
                      else
                      {
                          FxTrace.Exception.AsWarning(new WebException(
                              SR.GetString(SR.WebSocketInvalidProtocolInvalidCharInProtocolString, token, invalidChar)));
                          return false;
                      }
                  }
              }
          }

          return true;
      }

      /// <summary>
      /// Checks whether <paramref name="protocol"/> contains a character that is not allowed in a
      /// sub-protocol token.
      /// </summary>
      /// <param name="protocol">The token to check; must not be null (debug-asserted).</param>
      /// <param name="invalidChar">
      /// The first offending character. It is the character itself for separators, or its code point
      /// in brackets (e.g. "[9]") when outside the printable ASCII range 0x21-0x7E. Null when the
      /// token is valid.
      /// </param>
      /// <returns>True when an invalid character was found.</returns>
      internal static bool IsSubProtocolInvalid(string protocol, out string invalidChar)
      {
          Fx.Assert(protocol != null, "protocol should not be null");

          // Index the string directly rather than copying it with ToCharArray().
          for (int i = 0; i < protocol.Length; i++)
          {
              char ch = protocol[i];
              if (ch < 0x21 || ch > 0x7e)
              {
                  invalidChar = string.Format(CultureInfo.InvariantCulture, "[{0}]", (int)ch);
                  return true;
              }

              if (InvalidSeparatorSet.Contains(ch))
              {
                  invalidChar = ch.ToString();
                  return true;
              }
          }

          invalidChar = null;
          return false;
      }

      /// <summary>
      /// Returns the Sec-WebSocket-Version value (trimmed) that the platform puts on an outgoing
      /// ws:// request, computing it on first use and caching it.
      /// </summary>
      internal static string GetCurrentVersion()
      {
          // Read the cache once. The field is not synchronized, so concurrent first callers may each
          // compute the value. This is harmless provided every computation yields the same string.
          string version = currentWebSocketVersion;
          if (version == null)
          {
              // RegisterPrefixes is called before creating a ws:// request (presumably so that
              // WebRequest.Create accepts the ws scheme). The version is then read back from the
              // Sec-WebSocket-Version header that the created request carries.
              WebSocket.RegisterPrefixes();
              HttpWebRequest request = (HttpWebRequest)HttpWebRequest.Create("ws://localhost");
              string headerValue = request.Headers[WebSocketHelper.SecWebSocketVersion];
              Fx.Assert(headerValue != null, "version should not be null.");
              version = headerValue.Trim();
              currentWebSocketVersion = version;
          }

          return version;
      }

      /// <summary>
      /// Returns a clone of <paramref name="settings"/> in which a MaxPendingConnections equal to
      /// WebSocketDefaults.DefaultMaxPendingConnections is replaced by
      /// WebSocketDefaults.MaxPendingConnectionsCpuCount. Only the clone is modified (assumes Clone
      /// returns an independent copy).
      /// </summary>
      internal static WebSocketTransportSettings GetRuntimeWebSocketSettings(WebSocketTransportSettings settings)
      {
          WebSocketTransportSettings runtimeSettings = settings.Clone();
          if (runtimeSettings.MaxPendingConnections == WebSocketDefaults.DefaultMaxPendingConnections)
          {
              runtimeSettings.MaxPendingConnections = WebSocketDefaults.MaxPendingConnectionsCpuCount;
          }

          return runtimeSettings;
      }

      /// <summary>Returns OSEnvironmentHelper.IsAtLeast(OSVersion.Win8).</summary>
      internal static bool OSSupportsWebSockets()
      {
          return OSEnvironmentHelper.IsAtLeast(OSVersion.Win8);
      }

      /// <summary>
      /// Always throws the exception produced by ConvertAndTraceException(ex); never returns normally.
      /// </summary>
      [System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, FxCop.Rule.WrapExceptionsRule,
          Justification = "The exceptions thrown here are already wrapped.")]
      internal static void ThrowCorrectException(Exception ex)
      {
          throw ConvertAndTraceException(ex);
      }

      /// <summary>
      /// Always throws the exception produced by ConvertAndTraceException(ex, timeout, operation);
      /// never returns normally.
      /// </summary>
      [System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, FxCop.Rule.WrapExceptionsRule,
          Justification = "The exceptions thrown here are already wrapped.")]
      internal static void ThrowCorrectException(Exception ex, TimeSpan timeout, string operation)
      {
          throw ConvertAndTraceException(ex, timeout, operation);
      }

      /// <summary>
      /// Converts <paramref name="ex"/> without operation context. Any resulting timeout message is
      /// empty because no operation is named.
      /// </summary>
      internal static Exception ConvertAndTraceException(Exception ex)
      {
          return ConvertAndTraceException(
              ex,
              TimeSpan.MinValue, // this is a dummy since operation type is null, so the timespan value won't be used
              null);
      }

      /// <summary>
      /// Converts a low-level WebSocket/Task exception into the exception WCF callers expect, reports
      /// it through FxTrace, and returns it (the caller throws).
      /// <list type="bullet">
      /// <item>ObjectDisposedException becomes CommunicationObjectAbortedException (reported via AsWarning).</item>
      /// <item>
      /// AggregateException: FxTrace.Exception.AsError&lt;OperationCanceledException&gt; picks an inner
      /// exception. If that is an OperationCanceledException, the result is a TimeoutException
      /// (AsWarning). Otherwise ConvertAggregateExceptionToCommunicationException decides: an aborted
      /// result is reported via AsWarning, anything else via AsError.
      /// </item>
      /// <item>
      /// WebSocketException becomes ProtocolException for InvalidMessageType, UnsupportedProtocol and
      /// UnsupportedVersion, and CommunicationException for every other error code (AsError).
      /// </item>
      /// <item>Any other exception is passed through AsError unchanged.</item>
      /// </list>
      /// </summary>
      /// <param name="ex">The exception to convert.</param>
      /// <param name="timeout">Timeout quoted in a TimeoutException message, if one is produced.</param>
      /// <param name="operation">Operation name for the timeout message (see GetTimeoutException); may be null.</param>
      [System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, "Reliability103:ThrowWrappedExceptionsRule",
          Justification = "The exceptions wrapped here will be thrown out later.")]
      internal static Exception ConvertAndTraceException(Exception ex, TimeSpan timeout, string operation)
      {
          ObjectDisposedException objectDisposedException = ex as ObjectDisposedException;
          if (objectDisposedException != null)
          {
              CommunicationObjectAbortedException communicationObjectAbortedException = new CommunicationObjectAbortedException(ex.Message, ex);
              FxTrace.Exception.AsWarning(communicationObjectAbortedException);
              return communicationObjectAbortedException;
          }

          AggregateException aggregationException = ex as AggregateException;
          if (aggregationException != null)
          {
              // NOTE(review): the selection rule of AsError<T>(AggregateException) lives in FxTrace and is
              // not visible here; it presumably prefers an inner exception of type T.
              Exception exception = FxTrace.Exception.AsError<OperationCanceledException>(aggregationException);
              OperationCanceledException operationCanceledException = exception as OperationCanceledException;
              if (operationCanceledException != null)
              {
                  // A canceled task surfaces to WCF callers as a timeout of the named operation.
                  TimeoutException timeoutException = GetTimeoutException(exception, timeout, operation);
                  FxTrace.Exception.AsWarning(timeoutException);
                  return timeoutException;
              }
              else
              {
                  Exception communicationException = ConvertAggregateExceptionToCommunicationException(aggregationException);
                  if (communicationException is CommunicationObjectAbortedException)
                  {
                      FxTrace.Exception.AsWarning(communicationException);
                      return communicationException;
                  }
                  else
                  {
                      return FxTrace.Exception.AsError(communicationException);
                  }
              }
          }

          WebSocketException webSocketException = ex as WebSocketException;
          if (webSocketException != null)
          {
              switch (webSocketException.WebSocketErrorCode)
              {
                  case WebSocketError.InvalidMessageType:
                  case WebSocketError.UnsupportedProtocol:
                  case WebSocketError.UnsupportedVersion:
                      ex = new ProtocolException(ex.Message, ex);
                      break;
                  default:
                      ex = new CommunicationException(ex.Message, ex);
                      break;
              }
          }

          return FxTrace.Exception.AsError(ex);
      }

      /// <summary>
      /// Wraps the exception selected from <paramref name="ex"/> (preferring a WebSocketException,
      /// per the FxTrace helper) in a communication exception. The result is returned, not thrown.
      /// <list type="bullet">
      /// <item>
      /// A WebSocketException whose inner exception is an HttpListenerException is converted by
      /// HttpChannelUtilities.CreateCommunicationException.
      /// </item>
      /// <item>An ObjectDisposedException becomes CommunicationObjectAbortedException.</item>
      /// <item>Anything else becomes CommunicationException.</item>
      /// </list>
      /// </summary>
      [System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, "Reliability103",
          Justification = "The exceptions will be wrapped by the callers.")]
      internal static Exception ConvertAggregateExceptionToCommunicationException(AggregateException ex)
      {
          Exception exception = FxTrace.Exception.AsError<WebSocketException>(ex);
          WebSocketException webSocketException = exception as WebSocketException;
          if (webSocketException != null && webSocketException.InnerException != null)
          {
              HttpListenerException httpListenerException = webSocketException.InnerException as HttpListenerException;
              if (httpListenerException != null)
              {
                  return HttpChannelUtilities.CreateCommunicationException(httpListenerException);
              }
          }

          ObjectDisposedException objectDisposedException = exception as ObjectDisposedException;
          if (objectDisposedException != null)
          {
              return new CommunicationObjectAbortedException(exception.Message, exception);
          }

          return new CommunicationException(exception.Message, exception);
      }

      /// <summary>
      /// Throws if <paramref name="task"/> failed; returns normally otherwise.
      /// <list type="bullet">
      /// <item>
      /// A faulted task throws the result of
      /// FxTrace.Exception.AsError&lt;CommunicationException&gt;(task.Exception).
      /// </item>
      /// <item>A canceled task throws a TimeoutException for <paramref name="operation"/>.</item>
      /// </list>
      /// </summary>
      internal static void ThrowExceptionOnTaskFailure(Task task, TimeSpan timeout, string operation)
      {
          if (task.IsFaulted)
          {
              throw FxTrace.Exception.AsError<CommunicationException>(task.Exception);
          }
          else if (task.IsCanceled)
          {
              throw FxTrace.Exception.AsError(GetTimeoutException(null, timeout, operation));
          }
      }

      /// <summary>
      /// Builds (but does not throw or trace) a TimeoutException for <paramref name="operation"/>.
      /// <list type="bullet">
      /// <item>CloseOperation, SendOperation and ReceiveOperation get dedicated messages.</item>
      /// <item>Any other non-null name gets a generic message that includes the name.</item>
      /// <item>A null operation yields an empty message.</item>
      /// </list>
      /// </summary>
      /// <param name="innerException">Optional inner exception; null creates the exception without one.</param>
      /// <param name="timeout">The timeout quoted in the message.</param>
      /// <param name="operation">The operation that timed out; may be null.</param>
      internal static TimeoutException GetTimeoutException(Exception innerException, TimeSpan timeout, string operation)
      {
          // The explicit null check keeps a null operation out of the default branch, which would
          // otherwise format a message with a null operation name.
          string errorMsg = string.Empty;
          if (operation != null)
          {
              switch (operation)
              {
                  case WebSocketHelper.CloseOperation:
                      errorMsg = SR.GetString(SR.CloseTimedOut, timeout);
                      break;
                  case WebSocketHelper.SendOperation:
                      errorMsg = SR.GetString(SR.WebSocketSendTimedOut, timeout);
                      break;
                  case WebSocketHelper.ReceiveOperation:
                      errorMsg = SR.GetString(SR.WebSocketReceiveTimedOut, timeout);
                      break;
                  default:
                      errorMsg = SR.GetString(SR.WebSocketOperationTimedOut, operation, timeout);
                      break;
              }
          }

          return innerException == null ? new TimeoutException(errorMsg) : new TimeoutException(errorMsg, innerException);
      }

      /// <summary>
      /// Computes the total internal buffer size. The formula is twice the receive buffer size, plus
      /// the send buffer size, plus 144 bytes of native overhead, plus PropertyBufferSize. The send
      /// buffer is WebSocketDefaults.MinSendBufferSize for servers and WebSocketDefaults.BufferSize
      /// for clients.
      /// </summary>
      private static int ComputeInternalBufferSize(long maxReceivedMessageSize, bool isServerBuffer)
      {
          const int NativeOverheadBufferSize = 144;
          /* LAYOUT:
          | Native buffer              | PayloadReceiveBuffer | PropertyBuffer |
          | RBS + SBS + 144            | RBS                  | PBS            |
          | Only WSPC may modify       | Only WebSocketBase may modify         |
           *RBS = ReceiveBufferSize, *SBS = SendBufferSize
           *PBS = PropertyBufferSize (32-bit: 16, 64 bit: 20 bytes) */
          int nativeSendBufferSize = isServerBuffer ? WebSocketDefaults.MinSendBufferSize : WebSocketDefaults.BufferSize;
          return (2 * GetReceiveBufferSize(maxReceivedMessageSize)) + nativeSendBufferSize + NativeOverheadBufferSize + PropertyBufferSize;
      }
  }
  327. }