websocket.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. /**
  2. * Copyright (c) 2020-2021 Paul-Louis Ageneau
  3. *
  4. * This Source Code Form is subject to the terms of the Mozilla Public
  5. * License, v. 2.0. If a copy of the MPL was not distributed with this
  6. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  7. */
  8. #if RTC_ENABLE_WEBSOCKET
  9. #include "websocket.hpp"
  10. #include "common.hpp"
  11. #include "internals.hpp"
  12. #include "processor.hpp"
  13. #include "utils.hpp"
  14. #include "httpproxytransport.hpp"
  15. #include "tcptransport.hpp"
  16. #include "tlstransport.hpp"
  17. #include "verifiedtlstransport.hpp"
  18. #include "wstransport.hpp"
  19. #include <array>
  20. #include <chrono>
  21. #include <regex>
  22. #ifdef _WIN32
  23. #include <winsock2.h>
  24. #endif
  25. namespace rtc::impl {
  26. using namespace std::placeholders;
  27. using namespace std::chrono_literals;
  28. using std::chrono::milliseconds;
  29. WebSocket::WebSocket(optional<Configuration> optConfig, certificate_ptr certificate)
  30. : config(optConfig ? std::move(*optConfig) : Configuration()),
  31. mCertificate(certificate ? std::move(certificate) : std::move(loadCertificate(config))),
  32. mIsSecure(mCertificate != nullptr), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
  33. PLOG_VERBOSE << "Creating WebSocket";
  34. if (config.proxyServer) {
  35. if (config.proxyServer->type == ProxyServer::Type::Socks5)
  36. throw std::invalid_argument(
  37. "Proxy server support for WebSocket is not implemented for Socks5");
  38. if (config.proxyServer->username || config.proxyServer->password) {
  39. PLOG_WARNING << "HTTP authentication support for proxy is not implemented";
  40. }
  41. }
  42. }
  43. certificate_ptr WebSocket::loadCertificate(const Configuration& config) {
  44. if (!config.certificatePemFile)
  45. return nullptr;
  46. if (config.keyPemFile)
  47. return std::make_shared<Certificate>(
  48. Certificate::FromFile(*config.certificatePemFile, *config.keyPemFile,
  49. config.keyPemPass.value_or("")));
  50. throw std::invalid_argument(
  51. "Either none or both certificate and key PEM files must be specified");
  52. }
  53. WebSocket::~WebSocket() { PLOG_VERBOSE << "Destroying WebSocket"; }
  54. void WebSocket::open(const string &url) {
  55. PLOG_VERBOSE << "Opening WebSocket to URL: " << url;
  56. if (state != State::Closed)
  57. throw std::logic_error("WebSocket must be closed before opening");
  58. // Modified regex from RFC 3986, see https://www.rfc-editor.org/rfc/rfc3986.html#appendix-B
  59. static const char *rs =
  60. R"(^(([^:.@/?#]+):)?(/{0,2}((([^:@]*)(:([^@]*))?)@)?(([^:/?#]*)(:([^/?#]*))?))?([^?#]*)(\?([^#]*))?(#(.*))?)";
  61. static const std::regex r(rs, std::regex::extended);
  62. std::smatch m;
  63. if (!std::regex_match(url, m, r) || m[10].length() == 0)
  64. throw std::invalid_argument("Invalid WebSocket URL: " + url);
  65. string scheme = m[2];
  66. if (scheme.empty())
  67. scheme = "ws";
  68. if (scheme != "ws" && scheme != "wss")
  69. throw std::invalid_argument("Invalid WebSocket scheme: " + scheme);
  70. mIsSecure = (scheme != "ws");
  71. string username = utils::url_decode(m[6]);
  72. string password = utils::url_decode(m[8]);
  73. if (!username.empty() || !password.empty()) {
  74. PLOG_WARNING << "HTTP authentication support for WebSocket is not implemented";
  75. }
  76. string host;
  77. string hostname = m[10];
  78. string service = m[12];
  79. if (service.empty()) {
  80. service = mIsSecure ? "443" : "80";
  81. host = hostname;
  82. } else {
  83. host = hostname + ':' + service;
  84. }
  85. if (hostname.front() == '[' && hostname.back() == ']') {
  86. // IPv6 literal
  87. hostname.erase(hostname.begin());
  88. hostname.pop_back();
  89. } else {
  90. hostname = utils::url_decode(hostname);
  91. }
  92. string path = m[13];
  93. if (path.empty())
  94. path += '/';
  95. if (string query = m[15]; !query.empty())
  96. path += "?" + query;
  97. mHostname = hostname; // for TLS SNI and Proxy
  98. mService = service; // For proxy
  99. std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
  100. changeState(State::Connecting);
  101. if (config.proxyServer) {
  102. setTcpTransport(std::make_shared<TcpTransport>(
  103. config.proxyServer->hostname, std::to_string(config.proxyServer->port), nullptr));
  104. } else {
  105. setTcpTransport(std::make_shared<TcpTransport>(hostname, service, nullptr));
  106. }
  107. }
  108. void WebSocket::close() {
  109. auto s = state.load();
  110. if (s == State::Connecting || s == State::Open) {
  111. PLOG_VERBOSE << "Closing WebSocket";
  112. changeState(State::Closing);
  113. if (auto transport = std::atomic_load(&mWsTransport))
  114. transport->stop();
  115. else
  116. remoteClose();
  117. }
  118. }
  119. void WebSocket::remoteClose() {
  120. close();
  121. if (state.load() != State::Closed)
  122. closeTransports();
  123. }
  124. bool WebSocket::isOpen() const { return state == State::Open; }
  125. bool WebSocket::isClosed() const { return state == State::Closed; }
  126. size_t WebSocket::maxMessageSize() const { return config.maxMessageSize.value_or(DEFAULT_WS_MAX_MESSAGE_SIZE); }
  127. optional<message_variant> WebSocket::receive() {
  128. auto next = mRecvQueue.pop();
  129. return next ? std::make_optional(to_variant(std::move(**next))) : nullopt;
  130. }
  131. optional<message_variant> WebSocket::peek() {
  132. auto next = mRecvQueue.peek();
  133. return next ? std::make_optional(to_variant(std::move(**next))) : nullopt;
  134. }
  135. size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); }
  136. bool WebSocket::changeState(State newState) { return state.exchange(newState) != newState; }
  137. bool WebSocket::outgoing(message_ptr message) {
  138. if (state != State::Open || !mWsTransport)
  139. throw std::runtime_error("WebSocket is not open");
  140. if (message->size() > maxMessageSize())
  141. throw std::runtime_error("Message size exceeds limit");
  142. return mWsTransport->send(message);
  143. }
  144. void WebSocket::incoming(message_ptr message) {
  145. if (!message) {
  146. remoteClose();
  147. return;
  148. }
  149. if (message->type == Message::String || message->type == Message::Binary) {
  150. mRecvQueue.push(message);
  151. triggerAvailable(mRecvQueue.size());
  152. }
  153. }
  154. // Helper for WebSocket::initXTransport methods: start and emplace the transport
  155. template <typename T>
  156. shared_ptr<T> emplaceTransport(WebSocket *ws, shared_ptr<T> *member, shared_ptr<T> transport) {
  157. std::atomic_store(member, transport);
  158. try {
  159. transport->start();
  160. } catch (...) {
  161. std::atomic_store(member, decltype(transport)(nullptr));
  162. transport->stop();
  163. throw;
  164. }
  165. if (ws->state == WebSocket::State::Closed) {
  166. std::atomic_store(member, decltype(transport)(nullptr));
  167. transport->stop();
  168. return nullptr;
  169. }
  170. return transport;
  171. }
  172. shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> transport) {
  173. PLOG_VERBOSE << "Starting TCP transport";
  174. if (!transport)
  175. throw std::logic_error("TCP transport is null");
  176. using State = TcpTransport::State;
  177. try {
  178. if (std::atomic_load(&mTcpTransport))
  179. throw std::logic_error("TCP transport is already set");
  180. transport->onBufferedAmount(weak_bind(&WebSocket::triggerBufferedAmount, this, _1));
  181. transport->onStateChange([this, weak_this = weak_from_this()](State transportState) {
  182. auto shared_this = weak_this.lock();
  183. if (!shared_this)
  184. return;
  185. switch (transportState) {
  186. case State::Connected:
  187. if (config.proxyServer)
  188. initProxyTransport();
  189. else if (mIsSecure)
  190. initTlsTransport();
  191. else
  192. initWsTransport();
  193. break;
  194. case State::Failed:
  195. triggerError("TCP connection failed");
  196. remoteClose();
  197. break;
  198. case State::Disconnected:
  199. remoteClose();
  200. break;
  201. default:
  202. // Ignore
  203. break;
  204. }
  205. });
  206. // WS transport sends a ping on read timeout
  207. auto pingInterval = config.pingInterval.value_or(10000ms);
  208. if (pingInterval > milliseconds::zero())
  209. transport->setReadTimeout(pingInterval);
  210. scheduleConnectionTimeout();
  211. return emplaceTransport(this, &mTcpTransport, std::move(transport));
  212. } catch (const std::exception &e) {
  213. PLOG_ERROR << e.what();
  214. remoteClose();
  215. throw std::runtime_error("TCP transport initialization failed");
  216. }
  217. }
  218. shared_ptr<HttpProxyTransport> WebSocket::initProxyTransport() {
  219. PLOG_VERBOSE << "Starting Tcp Proxy transport";
  220. using State = HttpProxyTransport::State;
  221. try {
  222. if (auto transport = std::atomic_load(&mProxyTransport))
  223. return transport;
  224. auto lower = std::atomic_load(&mTcpTransport);
  225. if (!lower)
  226. throw std::logic_error("No underlying TCP transport for Proxy transport");
  227. auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
  228. auto shared_this = weak_this.lock();
  229. if (!shared_this)
  230. return;
  231. switch (transportState) {
  232. case State::Connected:
  233. if (mIsSecure)
  234. initTlsTransport();
  235. else
  236. initWsTransport();
  237. break;
  238. case State::Failed:
  239. triggerError("Proxy connection failed");
  240. remoteClose();
  241. break;
  242. case State::Disconnected:
  243. remoteClose();
  244. break;
  245. default:
  246. // Ignore
  247. break;
  248. }
  249. };
  250. auto transport = std::make_shared<HttpProxyTransport>(
  251. lower, mHostname.value(), mService.value(), stateChangeCallback);
  252. return emplaceTransport(this, &mProxyTransport, std::move(transport));
  253. } catch (const std::exception &e) {
  254. PLOG_ERROR << e.what();
  255. remoteClose();
  256. throw std::runtime_error("Tcp Proxy transport initialization failed");
  257. }
  258. }
  259. shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
  260. PLOG_VERBOSE << "Starting TLS transport";
  261. using State = TlsTransport::State;
  262. try {
  263. if (auto transport = std::atomic_load(&mTlsTransport))
  264. return transport;
  265. variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>> lower;
  266. if (config.proxyServer) {
  267. auto transport = std::atomic_load(&mProxyTransport);
  268. if (!transport)
  269. throw std::logic_error("No underlying proxy transport for TLS transport");
  270. lower = transport;
  271. } else {
  272. auto transport = std::atomic_load(&mTcpTransport);
  273. if (!transport)
  274. throw std::logic_error("No underlying TCP transport for TLS transport");
  275. lower = transport;
  276. }
  277. auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
  278. auto shared_this = weak_this.lock();
  279. if (!shared_this)
  280. return;
  281. switch (transportState) {
  282. case State::Connected:
  283. initWsTransport();
  284. break;
  285. case State::Failed:
  286. triggerError("TLS connection failed");
  287. remoteClose();
  288. break;
  289. case State::Disconnected:
  290. remoteClose();
  291. break;
  292. default:
  293. // Ignore
  294. break;
  295. }
  296. };
  297. bool verify = mHostname.has_value() && !config.disableTlsVerification;
  298. #ifdef _WIN32
  299. if (std::exchange(verify, false)) {
  300. PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows";
  301. }
  302. #endif
  303. shared_ptr<TlsTransport> transport;
  304. if (verify)
  305. transport = std::make_shared<VerifiedTlsTransport>(lower, mHostname.value(),
  306. mCertificate, stateChangeCallback,
  307. config.caCertificatePemFile);
  308. else
  309. transport =
  310. std::make_shared<TlsTransport>(lower, mHostname, mCertificate, stateChangeCallback);
  311. return emplaceTransport(this, &mTlsTransport, std::move(transport));
  312. } catch (const std::exception &e) {
  313. PLOG_ERROR << e.what();
  314. remoteClose();
  315. throw std::runtime_error("TLS transport initialization failed");
  316. }
  317. }
  318. shared_ptr<WsTransport> WebSocket::initWsTransport() {
  319. PLOG_VERBOSE << "Starting WebSocket transport";
  320. using State = WsTransport::State;
  321. try {
  322. if (auto transport = std::atomic_load(&mWsTransport))
  323. return transport;
  324. variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
  325. lower;
  326. if (mIsSecure) {
  327. auto transport = std::atomic_load(&mTlsTransport);
  328. if (!transport)
  329. throw std::logic_error("No underlying TLS transport for WebSocket transport");
  330. lower = transport;
  331. } else if (config.proxyServer) {
  332. auto transport = std::atomic_load(&mProxyTransport);
  333. if (!transport)
  334. throw std::logic_error("No underlying proxy transport for WebSocket transport");
  335. lower = transport;
  336. } else {
  337. auto transport = std::atomic_load(&mTcpTransport);
  338. if (!transport)
  339. throw std::logic_error("No underlying TCP transport for WebSocket transport");
  340. lower = transport;
  341. }
  342. if (!atomic_load(&mWsHandshake))
  343. atomic_store(&mWsHandshake, std::make_shared<WsHandshake>());
  344. auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
  345. auto shared_this = weak_this.lock();
  346. if (!shared_this)
  347. return;
  348. switch (transportState) {
  349. case State::Connected:
  350. if (state == WebSocket::State::Connecting) {
  351. PLOG_DEBUG << "WebSocket open";
  352. if (changeState(WebSocket::State::Open))
  353. triggerOpen();
  354. }
  355. break;
  356. case State::Failed:
  357. triggerError("WebSocket connection failed");
  358. remoteClose();
  359. break;
  360. case State::Disconnected:
  361. remoteClose();
  362. break;
  363. default:
  364. // Ignore
  365. break;
  366. }
  367. };
  368. auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, config,
  369. weak_bind(&WebSocket::incoming, this, _1),
  370. stateChangeCallback);
  371. return emplaceTransport(this, &mWsTransport, std::move(transport));
  372. } catch (const std::exception &e) {
  373. PLOG_ERROR << e.what();
  374. remoteClose();
  375. throw std::runtime_error("WebSocket transport initialization failed");
  376. }
  377. }
  378. shared_ptr<TcpTransport> WebSocket::getTcpTransport() const {
  379. return std::atomic_load(&mTcpTransport);
  380. }
  381. shared_ptr<TlsTransport> WebSocket::getTlsTransport() const {
  382. return std::atomic_load(&mTlsTransport);
  383. }
  384. shared_ptr<WsTransport> WebSocket::getWsTransport() const {
  385. return std::atomic_load(&mWsTransport);
  386. }
  387. shared_ptr<WsHandshake> WebSocket::getWsHandshake() const {
  388. return std::atomic_load(&mWsHandshake);
  389. }
  390. void WebSocket::closeTransports() {
  391. PLOG_VERBOSE << "Closing transports";
  392. if (!changeState(State::Closed))
  393. return; // already closed
  394. // Pass the pointers to a thread, allowing to terminate a transport from its own thread
  395. auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
  396. auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
  397. auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
  398. if (ws)
  399. ws->onRecv(nullptr);
  400. if (tcp)
  401. tcp->onBufferedAmount(nullptr);
  402. using array = std::array<shared_ptr<Transport>, 3>;
  403. array transports{std::move(ws), std::move(tls), std::move(tcp)};
  404. for (const auto &t : transports)
  405. if (t)
  406. t->onStateChange(nullptr);
  407. TearDownProcessor::Instance().enqueue(
  408. [transports = std::move(transports), token = Init::Instance().token()]() mutable {
  409. for (const auto &t : transports) {
  410. if (t) {
  411. t->stop();
  412. break;
  413. }
  414. }
  415. for (auto &t : transports)
  416. t.reset();
  417. });
  418. triggerClosed();
  419. }
  420. void WebSocket::scheduleConnectionTimeout() {
  421. auto defaultTimeout = 30s;
  422. auto timeout = config.connectionTimeout.value_or(milliseconds(defaultTimeout));
  423. if (timeout > milliseconds::zero()) {
  424. ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() {
  425. if (auto locked = weak_this.lock()) {
  426. if (locked->state == WebSocket::State::Connecting) {
  427. PLOG_WARNING << "WebSocket connection timed out";
  428. locked->triggerError("Connection timed out");
  429. locked->remoteClose();
  430. }
  431. }
  432. });
  433. }
  434. }
  435. } // namespace rtc::impl
  436. #endif