wshandshake.cpp 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. /**
  2. * Copyright (c) 2020-2021 Paul-Louis Ageneau
  3. *
  4. * This Source Code Form is subject to the terms of the Mozilla Public
  5. * License, v. 2.0. If a copy of the MPL was not distributed with this
  6. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  7. */
  8. #include "wshandshake.hpp"
  9. #include "internals.hpp"
  10. #include "sha.hpp"
  11. #include "utils.hpp"
  12. #if RTC_ENABLE_WEBSOCKET
  13. #include <algorithm>
  14. #include <chrono>
  15. #include <climits>
  16. #include <iostream>
  17. #include <random>
  18. #include <sstream>
  19. using std::string;
  20. namespace rtc::impl {
  21. using std::to_string;
  22. using std::chrono::system_clock;
  23. WsHandshake::WsHandshake() {}
  24. WsHandshake::WsHandshake(string host, string path, std::vector<string> protocols)
  25. : mHost(std::move(host)), mPath(std::move(path)), mProtocols(std::move(protocols)) {
  26. if (mHost.empty())
  27. throw std::invalid_argument("WebSocket HTTP host cannot be empty");
  28. if (mPath.empty())
  29. throw std::invalid_argument("WebSocket HTTP path cannot be empty");
  30. }
  31. string WsHandshake::host() const {
  32. std::unique_lock lock(mMutex);
  33. return mHost;
  34. }
  35. string WsHandshake::path() const {
  36. std::unique_lock lock(mMutex);
  37. return mPath;
  38. }
  39. std::vector<string> WsHandshake::protocols() const {
  40. std::unique_lock lock(mMutex);
  41. return mProtocols;
  42. }
  43. string WsHandshake::generateHttpRequest() {
  44. std::unique_lock lock(mMutex);
  45. mKey = generateKey();
  46. string out = "GET " + mPath +
  47. " HTTP/1.1\r\n"
  48. "Host: " +
  49. mHost +
  50. "\r\n"
  51. "Connection: upgrade\r\n"
  52. "Upgrade: websocket\r\n"
  53. "Sec-WebSocket-Version: 13\r\n"
  54. "Sec-WebSocket-Key: " +
  55. mKey + "\r\n";
  56. if (!mProtocols.empty())
  57. out += "Sec-WebSocket-Protocol: " + utils::implode(mProtocols, ',') + "\r\n";
  58. out += "\r\n";
  59. return out;
  60. }
  61. string WsHandshake::generateHttpResponse() {
  62. std::unique_lock lock(mMutex);
  63. const string out = "HTTP/1.1 101 Switching Protocols\r\n"
  64. "Server: libdatachannel\r\n"
  65. "Connection: upgrade\r\n"
  66. "Upgrade: websocket\r\n"
  67. "Sec-WebSocket-Accept: " +
  68. computeAcceptKey(mKey) + "\r\n\r\n";
  69. return out;
  70. }
  71. namespace {
  72. string GetHttpErrorName(int responseCode) {
  73. switch (responseCode) {
  74. case 400:
  75. return "Bad Request";
  76. case 404:
  77. return "Not Found";
  78. case 405:
  79. return "Method Not Allowed";
  80. case 426:
  81. return "Upgrade Required";
  82. case 500:
  83. return "Internal Server Error";
  84. default:
  85. return "Error";
  86. }
  87. }
  88. } // namespace
  89. string WsHandshake::generateHttpError(int responseCode) {
  90. std::unique_lock lock(mMutex);
  91. const string error = to_string(responseCode) + " " + GetHttpErrorName(responseCode);
  92. const string out = "HTTP/1.1 " + error +
  93. "\r\n"
  94. "Server: libdatachannel\r\n"
  95. "Connection: upgrade\r\n"
  96. "Upgrade: websocket\r\n"
  97. "Content-Type: text/plain\r\n"
  98. "Content-Length: " +
  99. to_string(error.size()) +
  100. "\r\n"
  101. "Access-Control-Allow-Origin: *\r\n\r\n" +
  102. error;
  103. return out;
  104. }
  105. size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) {
  106. std::unique_lock lock(mMutex);
  107. std::list<string> lines;
  108. size_t length = parseHttpLines(buffer, size, lines);
  109. if (length == 0)
  110. return 0;
  111. if (lines.empty())
  112. throw RequestError("Invalid HTTP request for WebSocket", 400);
  113. std::istringstream requestLine(std::move(lines.front()));
  114. lines.pop_front();
  115. string method, path, protocol;
  116. requestLine >> method >> path >> protocol;
  117. PLOG_DEBUG << "WebSocket request method=\"" << method << "\", path=\"" << path << "\"";
  118. if (method != "GET")
  119. throw RequestError("Invalid request method \"" + method + "\" for WebSocket", 405);
  120. mPath = std::move(path);
  121. auto headers = parseHttpHeaders(lines);
  122. auto h = headers.find("host");
  123. if (h == headers.end())
  124. throw RequestError("WebSocket host header missing in request", 400);
  125. mHost = std::move(h->second);
  126. h = headers.find("upgrade");
  127. if (h == headers.end())
  128. throw RequestError("WebSocket upgrade header missing in request", 426);
  129. string upgrade;
  130. std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
  131. [](char c) { return std::tolower(c); });
  132. if (upgrade != "websocket")
  133. throw RequestError("WebSocket upgrade header mismatching", 426);
  134. h = headers.find("sec-websocket-key");
  135. if (h == headers.end())
  136. throw RequestError("WebSocket key header missing in request", 400);
  137. mKey = std::move(h->second);
  138. h = headers.find("sec-websocket-protocol");
  139. if (h != headers.end())
  140. mProtocols = utils::explode(h->second, ',');
  141. return length;
  142. }
  143. size_t WsHandshake::parseHttpResponse(const byte *buffer, size_t size) {
  144. std::unique_lock lock(mMutex);
  145. std::list<string> lines;
  146. size_t length = parseHttpLines(buffer, size, lines);
  147. if (length == 0)
  148. return 0;
  149. if (lines.empty())
  150. throw Error("Invalid HTTP response for WebSocket");
  151. std::istringstream status(std::move(lines.front()));
  152. lines.pop_front();
  153. string protocol;
  154. unsigned int code = 0;
  155. status >> protocol >> code;
  156. PLOG_DEBUG << "WebSocket response code=" << code;
  157. if (code != 101)
  158. throw std::runtime_error("Unexpected response code " + to_string(code) + " for WebSocket");
  159. auto headers = parseHttpHeaders(lines);
  160. auto h = headers.find("upgrade");
  161. if (h == headers.end())
  162. throw Error("WebSocket update header missing");
  163. string upgrade;
  164. std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
  165. [](char c) { return std::tolower(c); });
  166. if (upgrade != "websocket")
  167. throw Error("WebSocket update header mismatching");
  168. h = headers.find("sec-websocket-accept");
  169. if (h == headers.end())
  170. throw Error("WebSocket accept header missing");
  171. if (h->second != computeAcceptKey(mKey))
  172. throw Error("WebSocket accept header is invalid");
  173. return length;
  174. }
  175. string WsHandshake::generateKey() {
  176. // RFC 6455: The request MUST include a header field with the name Sec-WebSocket-Key. The value
  177. // of this header field MUST be a nonce consisting of a randomly selected 16-byte value that has
  178. // been base64-encoded. [...] The nonce MUST be selected randomly for each connection.
  179. binary key(16);
  180. auto k = reinterpret_cast<uint8_t *>(key.data());
  181. std::generate(k, k + key.size(), utils::random_bytes_engine());
  182. return utils::base64_encode(key);
  183. }
  184. string WsHandshake::computeAcceptKey(const string &key) {
  185. return utils::base64_encode(Sha1(string(key) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
  186. }
  187. size_t WsHandshake::parseHttpLines(const byte *buffer, size_t size, std::list<string> &lines) {
  188. lines.clear();
  189. auto begin = reinterpret_cast<const char *>(buffer);
  190. auto end = begin + size;
  191. auto cur = begin;
  192. while (true) {
  193. auto last = cur;
  194. cur = std::find(cur, end, '\n');
  195. if (cur == end)
  196. return 0;
  197. string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
  198. if (line.empty())
  199. break;
  200. lines.emplace_back(std::move(line));
  201. }
  202. return cur - begin;
  203. }
  204. std::multimap<string, string> WsHandshake::parseHttpHeaders(const std::list<string> &lines) {
  205. std::multimap<string, string> headers;
  206. for (const auto &line : lines) {
  207. if (size_t pos = line.find_first_of(':'); pos != string::npos) {
  208. string key = line.substr(0, pos);
  209. string value = "";
  210. if (size_t subPos = line.find_first_not_of(' ', pos + 1); subPos != string::npos )
  211. {
  212. value = line.substr(subPos);
  213. }
  214. std::transform(key.begin(), key.end(), key.begin(),
  215. [](char c) { return std::tolower(c); });
  216. headers.emplace(std::move(key), std::move(value));
  217. } else {
  218. headers.emplace(line, "");
  219. }
  220. }
  221. return headers;
  222. }
  223. WsHandshake::Error::Error(const string &w) : std::runtime_error(w) {}
  224. WsHandshake::RequestError::RequestError(const string &w, int responseCode)
  225. : Error(w), mResponseCode(responseCode) {}
  226. int WsHandshake::RequestError::RequestError::responseCode() const { return mResponseCode; }
  227. } // namespace rtc::impl
  228. #endif