wstransport.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. /**
  2. * Copyright (c) 2020-2021 Paul-Louis Ageneau
  3. *
  4. * This Source Code Form is subject to the terms of the Mozilla Public
  5. * License, v. 2.0. If a copy of the MPL was not distributed with this
  6. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  7. */
  8. #include "wstransport.hpp"
  9. #include "httpproxytransport.hpp"
  10. #include "tcptransport.hpp"
  11. #include "threadpool.hpp"
  12. #include "tlstransport.hpp"
  13. #include "utils.hpp"
  14. #if RTC_ENABLE_WEBSOCKET
  15. #include <algorithm>
  16. #include <chrono>
  17. #include <iostream>
  18. #include <numeric>
  19. #include <random>
  20. #include <regex>
  21. #include <sstream>
  22. #ifdef _WIN32
  23. #include <winsock2.h>
  24. #else
  25. #include <arpa/inet.h>
  26. #endif
  27. #ifndef htonll
  28. #define htonll(x) \
  29. ((uint64_t)(((uint64_t)htonl((uint32_t)(x))) << 32) | (uint64_t)htonl((uint32_t)((x) >> 32)))
  30. #endif
  31. #ifndef ntohll
  32. #define ntohll(x) htonll(x)
  33. #endif
  34. namespace rtc::impl {
  35. using std::to_integer;
  36. using std::to_string;
  37. using std::chrono::system_clock;
  38. WsTransport::WsTransport(LowerTransport lower, shared_ptr<WsHandshake> handshake,
  39. const WebSocketConfiguration &config, message_callback recvCallback,
  40. state_callback stateCallback)
  41. : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
  42. std::move(stateCallback)),
  43. mHandshake(std::move(handshake)),
  44. mIsClient(
  45. std::visit(rtc::overloaded{[](auto l) { return l->isActive(); },
  46. [](shared_ptr<TlsTransport> l) { return l->isClient(); }},
  47. lower)),
  48. mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_WS_MAX_MESSAGE_SIZE)),
  49. mMaxOutstandingPings(config.maxOutstandingPings.value_or(0)) {
  50. onRecv(std::move(recvCallback));
  51. PLOG_DEBUG << "Initializing WebSocket transport";
  52. }
  53. WsTransport::~WsTransport() { unregisterIncoming(); }
  54. void WsTransport::start() {
  55. registerIncoming();
  56. changeState(State::Connecting);
  57. if (mIsClient)
  58. sendHttpRequest();
  59. }
  60. void WsTransport::stop() { close(); }
  61. bool WsTransport::send(message_ptr message) {
  62. if (state() != State::Connected)
  63. throw std::runtime_error("WebSocket is not open");
  64. if (!message)
  65. return false;
  66. PLOG_VERBOSE << "Send size=" << message->size();
  67. return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(),
  68. message->size(), true, mIsClient});
  69. }
  70. void WsTransport::close() {
  71. if (state() != State::Connected)
  72. return;
  73. if (mCloseSent.exchange(true))
  74. return;
  75. PLOG_INFO << "WebSocket closing";
  76. try {
  77. sendFrame({CLOSE, NULL, 0, true, mIsClient});
  78. } catch (const std::exception &e) {
  79. // The connection might not be open anymore
  80. PLOG_DEBUG << "Unable to send WebSocket close frame: " << e.what();
  81. changeState(State::Disconnected);
  82. return;
  83. }
  84. ThreadPool::Instance().schedule(std::chrono::seconds(10),
  85. [this, weak_this = weak_from_this()]() {
  86. if (auto shared_this = weak_this.lock()) {
  87. PLOG_DEBUG << "WebSocket close timeout";
  88. changeState(State::Disconnected);
  89. }
  90. });
  91. }
// Entry point for data coming from the lower transport.
// A non-null message is appended to mBuffer. While Connecting, the buffer is fed to the HTTP
// handshake parser (response on the client side, request on the server side); once it has been
// parsed the state becomes Connected and the parsed bytes are dropped from the buffer. While
// Connected, an empty message triggers a ping; otherwise complete frames are parsed and
// dispatched to recvFrame(). A null message (presumably the lower transport going away) or any
// caught error ends the connection: Disconnected if it was open, Failed otherwise.
void WsTransport::incoming(message_ptr message) {
	auto s = state();
	if (s != State::Connecting && s != State::Connected)
		return; // Drop

	if (message) {
		PLOG_VERBOSE << "Incoming size=" << message->size();

		try {
			mBuffer.insert(mBuffer.end(), message->begin(), message->end());

			if (state() == State::Connecting) {
				// A non-zero return is the length of the parsed HTTP message, which is then
				// erased from the buffer; 0 means nothing could be parsed yet
				if (mIsClient) {
					if (size_t len =
					        mHandshake->parseHttpResponse(mBuffer.data(), mBuffer.size())) {
						PLOG_INFO << "WebSocket client-side open";
						changeState(State::Connected);
						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
					}
				} else {
					if (size_t len = mHandshake->parseHttpRequest(mBuffer.data(), mBuffer.size())) {
						PLOG_INFO << "WebSocket server-side open";
						sendHttpResponse();
						changeState(State::Connected);
						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
					}
				}
			}

			// Deliberately not an else branch: bytes following the handshake in the same
			// buffer are processed right away
			if (state() == State::Connected) {
				if (message->size() == 0) {
					// TCP is idle, send a ping
					PLOG_DEBUG << "WebSocket sending ping";
					uint32_t dummy = 0;
					sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
					addOutstandingPing();
				} else {
					// Skip bytes still belonging to a previous frame longer than the buffer
					if (mIgnoreLength > 0) {
						size_t len = std::min(mIgnoreLength, mBuffer.size());
						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
						mIgnoreLength -= len;
					}
					if (mIgnoreLength == 0) {
						Frame frame;
						// parseFrame() returns 0 until a complete frame is available
						while (size_t len = parseFrame(mBuffer.data(), mBuffer.size(), frame)) {
							recvFrame(frame);
							// The encoded frame extends past the buffer (its payload was
							// truncated): drop the buffer and ignore the rest as it arrives
							if (len > mBuffer.size()) {
								mIgnoreLength = len - mBuffer.size();
								mBuffer.clear();
								break;
							}
							mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
						}
					}
				}
			}

			return;

		} catch (const WsHandshake::RequestError &e) {
			// Rejected handshake request: answer with the corresponding HTTP error
			PLOG_WARNING << e.what();
			try {
				sendHttpError(e.responseCode());
			} catch (const std::exception &e) {
				PLOG_WARNING << e.what();
			}
		} catch (const WsHandshake::Error &e) {
			PLOG_WARNING << e.what();
		} catch (const std::exception &e) {
			PLOG_ERROR << e.what();
		}
	}

	// Reached on a null message or after a caught error
	if (state() == State::Connected) {
		PLOG_INFO << "WebSocket disconnected";
		changeState(State::Disconnected);
		recv(nullptr);
	} else {
		PLOG_ERROR << "WebSocket handshake failed";
		changeState(State::Failed);
	}
}
  167. bool WsTransport::sendHttpRequest() {
  168. PLOG_DEBUG << "Sending WebSocket HTTP request";
  169. const string request = mHandshake->generateHttpRequest();
  170. auto data = reinterpret_cast<const byte *>(request.data());
  171. return outgoing(make_message(data, data + request.size()));
  172. }
  173. bool WsTransport::sendHttpResponse() {
  174. PLOG_DEBUG << "Sending WebSocket HTTP response";
  175. const string response = mHandshake->generateHttpResponse();
  176. auto data = reinterpret_cast<const byte *>(response.data());
  177. return outgoing(make_message(data, data + response.size()));
  178. }
  179. bool WsTransport::sendHttpError(int code) {
  180. PLOG_WARNING << "Sending WebSocket HTTP error response " << code;
  181. const string response = mHandshake->generateHttpError(code);
  182. auto data = reinterpret_cast<const byte *>(response.data());
  183. return outgoing(make_message(data, data + response.size()));
  184. }
  185. // RFC6455 5.2. Base Framing Protocol
  186. // https://www.rfc-editor.org/rfc/rfc6455.html#section-5.2
  187. //
  188. // 0 1 2 3
  189. // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
  190. // +-+-+-+-+-------+-+-------------+-------------------------------+
  191. // |F|R|R|R| opcode|M| Payload len | Extended payload length |
  192. // |I|S|S|S| (4) |A| (7) | (16/64) |
  193. // |N|V|V|V| |S| | (if payload len==126/127) |
  194. // | |1|2|3| |K| | |
  195. // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
  196. // | Extended payload length continued, if payload len == 127 |
  197. // + - - - - - - - - - - - - - - - +-------------------------------+
  198. // | | Masking-key, if MASK set to 1 |
  199. // +-------------------------------+-------------------------------+
  200. // | Masking-key (continued) | Payload Data |
  201. // +-------------------------------+ - - - - - - - - - - - - - - - +
  202. // : Payload Data continued ... :
  203. // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
  204. // | Payload Data continued ... |
  205. // +---------------------------------------------------------------+
  206. size_t WsTransport::parseFrame(byte *buffer, size_t size, Frame &frame) {
  207. const byte *end = buffer + size;
  208. if (end - buffer < 2)
  209. return 0;
  210. byte *cur = buffer;
  211. auto b1 = to_integer<uint8_t>(*cur++);
  212. auto b2 = to_integer<uint8_t>(*cur++);
  213. frame.fin = (b1 & 0x80) != 0;
  214. frame.mask = (b2 & 0x80) != 0;
  215. frame.opcode = static_cast<Opcode>(b1 & 0x0F);
  216. frame.length = b2 & 0x7F;
  217. if (frame.length == 0x7E) {
  218. if (end - cur < 2)
  219. return 0;
  220. frame.length = ntohs(*reinterpret_cast<const uint16_t *>(cur));
  221. cur += 2;
  222. } else if (frame.length == 0x7F) {
  223. if (end - cur < 8)
  224. return 0;
  225. frame.length = ntohll(*reinterpret_cast<const uint64_t *>(cur));
  226. cur += 8;
  227. }
  228. const byte *maskingKey = nullptr;
  229. if (frame.mask) {
  230. if (end - cur < 4)
  231. return 0;
  232. maskingKey = cur;
  233. cur += 4;
  234. }
  235. const size_t maxControlFrameLength = 125;
  236. const size_t maxFrameLength = std::max(maxControlFrameLength, mMaxMessageSize);
  237. if (size_t(end - cur) < std::min(frame.length, maxFrameLength))
  238. return 0;
  239. size_t length = frame.length;
  240. if (frame.length > maxFrameLength) {
  241. PLOG_WARNING << "WebSocket frame is too large (length=" << frame.length
  242. << "), truncating it";
  243. frame.length = maxFrameLength;
  244. }
  245. frame.payload = cur;
  246. if (maskingKey)
  247. for (size_t i = 0; i < frame.length; ++i)
  248. frame.payload[i] ^= maskingKey[i % 4];
  249. return frame.payload + length - buffer; // can be more than buffer size
  250. }
  251. void WsTransport::recvFrame(const Frame &frame) {
  252. PLOG_DEBUG << "WebSocket received frame: opcode=" << int(frame.opcode)
  253. << ", length=" << frame.length;
  254. switch (frame.opcode) {
  255. case TEXT_FRAME:
  256. case BINARY_FRAME: {
  257. size_t size = frame.length;
  258. if (size > mMaxMessageSize) {
  259. PLOG_WARNING << "WebSocket message is too large, truncating it";
  260. size = mMaxMessageSize;
  261. }
  262. if (!mPartial.empty()) {
  263. PLOG_WARNING << "WebSocket unfinished message: type="
  264. << (mPartialOpcode == TEXT_FRAME ? "text" : "binary")
  265. << ", size=" << mPartial.size();
  266. auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
  267. recv(make_message(mPartial.begin(), mPartial.end(), type));
  268. mPartial.clear();
  269. }
  270. mPartialOpcode = frame.opcode;
  271. if (frame.fin) {
  272. PLOG_DEBUG << "WebSocket finished message: type="
  273. << (frame.opcode == TEXT_FRAME ? "text" : "binary") << ", size=" << size;
  274. auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
  275. recv(make_message(frame.payload, frame.payload + size, type));
  276. } else {
  277. mPartial.insert(mPartial.end(), frame.payload, frame.payload + size);
  278. }
  279. break;
  280. }
  281. case CONTINUATION: {
  282. mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
  283. if (mPartial.size() > mMaxMessageSize) {
  284. PLOG_WARNING << "WebSocket message is too large, truncating it";
  285. mPartial.resize(mMaxMessageSize);
  286. }
  287. if (frame.fin) {
  288. PLOG_DEBUG << "WebSocket finished message: type="
  289. << (frame.opcode == TEXT_FRAME ? "text" : "binary")
  290. << ", size=" << mPartial.size();
  291. auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
  292. recv(make_message(mPartial.begin(), mPartial.end(), type));
  293. mPartial.clear();
  294. }
  295. break;
  296. }
  297. case PING: {
  298. PLOG_DEBUG << "WebSocket received ping, sending pong";
  299. sendFrame({PONG, frame.payload, frame.length, true, mIsClient});
  300. break;
  301. }
  302. case PONG: {
  303. PLOG_DEBUG << "WebSocket received pong";
  304. mOutstandingPings = 0;
  305. break;
  306. }
  307. case CLOSE: {
  308. PLOG_INFO << "WebSocket closed";
  309. close();
  310. changeState(State::Disconnected);
  311. break;
  312. }
  313. default: {
  314. PLOG_ERROR << "Unknown WebSocket opcode: " + to_string(frame.opcode);
  315. close();
  316. break;
  317. }
  318. }
  319. }
  320. bool WsTransport::sendFrame(const Frame &frame) {
  321. std::lock_guard lock(mSendMutex);
  322. PLOG_DEBUG << "WebSocket sending frame: opcode=" << int(frame.opcode)
  323. << ", length=" << frame.length;
  324. byte buffer[14];
  325. byte *cur = buffer;
  326. *cur++ = byte((frame.opcode & 0x0F) | (frame.fin ? 0x80 : 0));
  327. if (frame.length < 0x7E) {
  328. *cur++ = byte((frame.length & 0x7F) | (frame.mask ? 0x80 : 0));
  329. } else if (frame.length <= 0xFFFF) {
  330. *cur++ = byte(0x7E | (frame.mask ? 0x80 : 0));
  331. *reinterpret_cast<uint16_t *>(cur) = htons(uint16_t(frame.length));
  332. cur += 2;
  333. } else {
  334. *cur++ = byte(0x7F | (frame.mask ? 0x80 : 0));
  335. *reinterpret_cast<uint64_t *>(cur) = htonll(uint64_t(frame.length));
  336. cur += 8;
  337. }
  338. if (frame.mask) {
  339. byte *maskingKey = reinterpret_cast<byte *>(cur);
  340. auto u = reinterpret_cast<uint8_t *>(maskingKey);
  341. std::generate(u, u + 4, utils::random_bytes_engine());
  342. cur += 4;
  343. for (size_t i = 0; i < frame.length; ++i)
  344. frame.payload[i] ^= maskingKey[i % 4];
  345. }
  346. const size_t length = cur - buffer; // header length
  347. auto message = make_message(length + frame.length);
  348. std::copy(buffer, buffer + length, message->begin()); // header
  349. std::copy(frame.payload, frame.payload + frame.length,
  350. message->begin() + length); // payload
  351. return outgoing(std::move(message));
  352. }
  353. void WsTransport::addOutstandingPing() {
  354. ++mOutstandingPings;
  355. if (mMaxOutstandingPings > 0 && mOutstandingPings > mMaxOutstandingPings) {
  356. changeState(State::Failed);
  357. }
  358. }
  359. } // namespace rtc::impl
  360. #endif