wshandshake.cpp 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. /**
  2. * Copyright (c) 2020-2021 Paul-Louis Ageneau
  3. *
  4. * This Source Code Form is subject to the terms of the Mozilla Public
  5. * License, v. 2.0. If a copy of the MPL was not distributed with this
  6. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  7. */
  8. #include "wshandshake.hpp"
  9. #include "internals.hpp"
  10. #include "sha.hpp"
  11. #include "utils.hpp"
  12. #if RTC_ENABLE_WEBSOCKET
  13. #include <algorithm>
  14. #include <chrono>
  15. #include <climits>
  16. #include <iostream>
  17. #include <random>
  18. #include <sstream>
  19. using std::string;
  20. namespace rtc::impl {
  21. using std::to_string;
  22. using std::chrono::system_clock;
  23. WsHandshake::WsHandshake() {}
  24. WsHandshake::WsHandshake(string host, string path, std::vector<string> protocols)
  25. : mHost(std::move(host)), mPath(std::move(path)), mProtocols(std::move(protocols)) {
  26. if (mHost.empty())
  27. throw std::invalid_argument("WebSocket HTTP host cannot be empty");
  28. if (mPath.empty())
  29. throw std::invalid_argument("WebSocket HTTP path cannot be empty");
  30. }
  31. string WsHandshake::host() const {
  32. std::unique_lock lock(mMutex);
  33. return mHost;
  34. }
  35. string WsHandshake::path() const {
  36. std::unique_lock lock(mMutex);
  37. return mPath;
  38. }
  39. std::vector<string> WsHandshake::protocols() const {
  40. std::unique_lock lock(mMutex);
  41. return mProtocols;
  42. }
  43. string WsHandshake::generateHttpRequest() {
  44. std::unique_lock lock(mMutex);
  45. mKey = generateKey();
  46. string out = "GET " + mPath +
  47. " HTTP/1.1\r\n"
  48. "Host: " +
  49. mHost +
  50. "\r\n"
  51. "Connection: upgrade\r\n"
  52. "Upgrade: websocket\r\n"
  53. "Sec-WebSocket-Version: 13\r\n"
  54. "Sec-WebSocket-Key: " +
  55. mKey + "\r\n";
  56. if (!mProtocols.empty())
  57. out += "Sec-WebSocket-Protocol: " + utils::implode(mProtocols, ',') + "\r\n";
  58. out += "\r\n";
  59. return out;
  60. }
  61. string WsHandshake::generateHttpResponse() {
  62. std::unique_lock lock(mMutex);
  63. const string out = "HTTP/1.1 101 Switching Protocols\r\n"
  64. "Server: libdatachannel\r\n"
  65. "Connection: upgrade\r\n"
  66. "Upgrade: websocket\r\n"
  67. "Sec-WebSocket-Accept: " +
  68. computeAcceptKey(mKey) + "\r\n\r\n";
  69. return out;
  70. }
  namespace {
  // Maps the HTTP status codes this file emits to their reason phrases.
  // Any code not listed falls back to the generic phrase "Error".
  string GetHttpErrorName(int responseCode) {
      struct Reason {
          int code;
          const char *phrase;
      };
      static constexpr Reason kReasons[] = {
          {400, "Bad Request"},
          {404, "Not Found"},
          {405, "Method Not Allowed"},
          {426, "Upgrade Required"},
          {500, "Internal Server Error"},
      };

      for (const auto &[code, phrase] : kReasons)
          if (code == responseCode)
              return phrase;

      return "Error";
  }
  } // namespace
  89. string WsHandshake::generateHttpError(int responseCode) {
  90. std::unique_lock lock(mMutex);
  91. const string error = to_string(responseCode) + " " + GetHttpErrorName(responseCode);
  92. const string out = "HTTP/1.1 " + error +
  93. "\r\n"
  94. "Server: libdatachannel\r\n"
  95. "Connection: upgrade\r\n"
  96. "Upgrade: websocket\r\n"
  97. "Content-Type: text/plain\r\n"
  98. "Content-Length: " +
  99. to_string(error.size()) +
  100. "\r\n"
  101. "Access-Control-Allow-Origin: *\r\n\r\n" +
  102. error;
  103. return out;
  104. }
  105. size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) {
  106. if (!utils::IsHttpRequest(buffer, size))
  107. throw RequestError("Invalid HTTP request for WebSocket", 400);
  108. std::unique_lock lock(mMutex);
  109. std::list<string> lines;
  110. size_t length = parseHttpLines(buffer, size, lines);
  111. if (length == 0)
  112. return 0;
  113. if (lines.empty())
  114. throw RequestError("Invalid HTTP request for WebSocket", 400);
  115. std::istringstream requestLine(std::move(lines.front()));
  116. lines.pop_front();
  117. string method, path, protocol;
  118. requestLine >> method >> path >> protocol;
  119. PLOG_DEBUG << "WebSocket request method=\"" << method << "\", path=\"" << path << "\"";
  120. if (method != "GET")
  121. throw RequestError("Invalid request method \"" + method + "\" for WebSocket", 405);
  122. mPath = std::move(path);
  123. auto headers = parseHttpHeaders(lines);
  124. auto h = headers.find("host");
  125. if (h == headers.end())
  126. throw RequestError("WebSocket host header missing in request", 400);
  127. mHost = std::move(h->second);
  128. h = headers.find("upgrade");
  129. if (h == headers.end())
  130. throw RequestError("WebSocket upgrade header missing in request", 426);
  131. string upgrade;
  132. std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
  133. [](char c) { return std::tolower(c); });
  134. if (upgrade != "websocket")
  135. throw RequestError("WebSocket upgrade header mismatching", 426);
  136. h = headers.find("sec-websocket-key");
  137. if (h == headers.end())
  138. throw RequestError("WebSocket key header missing in request", 400);
  139. mKey = std::move(h->second);
  140. h = headers.find("sec-websocket-protocol");
  141. if (h != headers.end())
  142. mProtocols = utils::explode(h->second, ',');
  143. return length;
  144. }
  145. size_t WsHandshake::parseHttpResponse(const byte *buffer, size_t size) {
  146. std::unique_lock lock(mMutex);
  147. std::list<string> lines;
  148. size_t length = parseHttpLines(buffer, size, lines);
  149. if (length == 0)
  150. return 0;
  151. if (lines.empty())
  152. throw Error("Invalid HTTP response for WebSocket");
  153. std::istringstream status(std::move(lines.front()));
  154. lines.pop_front();
  155. string protocol;
  156. unsigned int code = 0;
  157. status >> protocol >> code;
  158. PLOG_DEBUG << "WebSocket response code=" << code;
  159. if (code != 101)
  160. throw std::runtime_error("Unexpected response code " + to_string(code) + " for WebSocket");
  161. auto headers = parseHttpHeaders(lines);
  162. auto h = headers.find("upgrade");
  163. if (h == headers.end())
  164. throw Error("WebSocket update header missing");
  165. string upgrade;
  166. std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
  167. [](char c) { return std::tolower(c); });
  168. if (upgrade != "websocket")
  169. throw Error("WebSocket update header mismatching");
  170. h = headers.find("sec-websocket-accept");
  171. if (h == headers.end())
  172. throw Error("WebSocket accept header missing");
  173. if (h->second != computeAcceptKey(mKey))
  174. throw Error("WebSocket accept header is invalid");
  175. return length;
  176. }
  177. string WsHandshake::generateKey() {
  178. // RFC 6455: The request MUST include a header field with the name Sec-WebSocket-Key. The value
  179. // of this header field MUST be a nonce consisting of a randomly selected 16-byte value that has
  180. // been base64-encoded. [...] The nonce MUST be selected randomly for each connection.
  181. binary key(16);
  182. auto k = reinterpret_cast<uint8_t *>(key.data());
  183. std::generate(k, k + key.size(), utils::random_bytes_engine());
  184. return utils::base64_encode(key);
  185. }
  186. string WsHandshake::computeAcceptKey(const string &key) {
  187. return utils::base64_encode(Sha1(string(key) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
  188. }
  189. size_t WsHandshake::parseHttpLines(const byte *buffer, size_t size, std::list<string> &lines) {
  190. lines.clear();
  191. auto begin = reinterpret_cast<const char *>(buffer);
  192. auto end = begin + size;
  193. auto cur = begin;
  194. while (true) {
  195. auto last = cur;
  196. cur = std::find(cur, end, '\n');
  197. if (cur == end)
  198. return 0;
  199. string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
  200. if (line.empty())
  201. break;
  202. lines.emplace_back(std::move(line));
  203. }
  204. return cur - begin;
  205. }
  206. std::multimap<string, string> WsHandshake::parseHttpHeaders(const std::list<string> &lines) {
  207. std::multimap<string, string> headers;
  208. for (const auto &line : lines) {
  209. if (size_t pos = line.find_first_of(':'); pos != string::npos) {
  210. string key = line.substr(0, pos);
  211. string value = "";
  212. if (size_t subPos = line.find_first_not_of(' ', pos + 1); subPos != string::npos )
  213. {
  214. value = line.substr(subPos);
  215. }
  216. std::transform(key.begin(), key.end(), key.begin(),
  217. [](char c) { return std::tolower(c); });
  218. headers.emplace(std::move(key), std::move(value));
  219. } else {
  220. headers.emplace(line, "");
  221. }
  222. }
  223. return headers;
  224. }
  225. WsHandshake::Error::Error(const string &w) : std::runtime_error(w) {}
  226. WsHandshake::RequestError::RequestError(const string &w, int responseCode)
  227. : Error(w), mResponseCode(responseCode) {}
  228. int WsHandshake::RequestError::RequestError::responseCode() const { return mResponseCode; }
  229. } // namespace rtc::impl
  230. #endif