websocket.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. /**
  2. * Copyright (c) 2020-2021 Paul-Louis Ageneau
  3. *
  4. * This Source Code Form is subject to the terms of the Mozilla Public
  5. * License, v. 2.0. If a copy of the MPL was not distributed with this
  6. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  7. */
  8. #if RTC_ENABLE_WEBSOCKET
  9. #include "websocket.hpp"
  10. #include "common.hpp"
  11. #include "internals.hpp"
  12. #include "processor.hpp"
  13. #include "utils.hpp"
  14. #include "httpproxytransport.hpp"
  15. #include "tcptransport.hpp"
  16. #include "tlstransport.hpp"
  17. #include "verifiedtlstransport.hpp"
  18. #include "wstransport.hpp"
  19. #include <array>
  20. #include <chrono>
  21. #include <regex>
  22. #ifdef _WIN32
  23. #include <winsock2.h>
  24. #endif
  25. namespace rtc::impl {
  26. using namespace std::placeholders;
  27. using namespace std::chrono_literals;
  28. using std::chrono::milliseconds;
  29. const string PemBeginCertificateTag = "-----BEGIN CERTIFICATE-----";
  30. WebSocket::WebSocket(optional<Configuration> optConfig, certificate_ptr certificate)
  31. : config(optConfig ? std::move(*optConfig) : Configuration()),
  32. mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
  33. PLOG_VERBOSE << "Creating WebSocket";
  34. if (certificate) {
  35. mCertificate = std::move(certificate);
  36. } else if (config.certificatePemFile && config.keyPemFile) {
  37. mCertificate = std::make_shared<Certificate>(
  38. config.certificatePemFile->find(PemBeginCertificateTag) != string::npos
  39. ? Certificate::FromString(*config.certificatePemFile, *config.keyPemFile)
  40. : Certificate::FromFile(*config.certificatePemFile, *config.keyPemFile,
  41. config.keyPemPass.value_or("")));
  42. } else if (config.certificatePemFile || config.keyPemFile) {
  43. throw std::invalid_argument(
  44. "Either none or both certificate and key PEM files must be specified");
  45. }
  46. mIsSecure = mCertificate != nullptr;
  47. if (config.proxyServer) {
  48. if (config.proxyServer->type == ProxyServer::Type::Socks5)
  49. throw std::invalid_argument(
  50. "Proxy server support for WebSocket is not implemented for Socks5");
  51. if (config.proxyServer->username || config.proxyServer->password) {
  52. PLOG_WARNING << "HTTP authentication support for proxy is not implemented";
  53. }
  54. }
  55. }
  56. WebSocket::~WebSocket() { PLOG_VERBOSE << "Destroying WebSocket"; }
  57. void WebSocket::open(const string &url) {
  58. PLOG_VERBOSE << "Opening WebSocket to URL: " << url;
  59. if (state != State::Closed)
  60. throw std::logic_error("WebSocket must be closed before opening");
  61. // Modified regex from RFC 3986, see https://www.rfc-editor.org/rfc/rfc3986.html#appendix-B
  62. static const char *rs =
  63. R"(^(([^:.@/?#]+):)?(/{0,2}((([^:@]*)(:([^@]*))?)@)?(([^:/?#]*)(:([^/?#]*))?))?([^?#]*)(\?([^#]*))?(#(.*))?)";
  64. static const std::regex r(rs, std::regex::extended);
  65. std::smatch m;
  66. if (!std::regex_match(url, m, r) || m[10].length() == 0)
  67. throw std::invalid_argument("Invalid WebSocket URL: " + url);
  68. string scheme = m[2];
  69. if (scheme.empty())
  70. scheme = "ws";
  71. if (scheme != "ws" && scheme != "wss")
  72. throw std::invalid_argument("Invalid WebSocket scheme: " + scheme);
  73. mIsSecure = (scheme != "ws");
  74. string username = utils::url_decode(m[6]);
  75. string password = utils::url_decode(m[8]);
  76. if (!username.empty() || !password.empty()) {
  77. PLOG_WARNING << "HTTP authentication support for WebSocket is not implemented";
  78. }
  79. string host;
  80. string hostname = m[10];
  81. string service = m[12];
  82. if (service.empty()) {
  83. service = mIsSecure ? "443" : "80";
  84. host = hostname;
  85. } else {
  86. host = hostname + ':' + service;
  87. }
  88. if (hostname.front() == '[' && hostname.back() == ']') {
  89. // IPv6 literal
  90. hostname.erase(hostname.begin());
  91. hostname.pop_back();
  92. } else {
  93. hostname = utils::url_decode(hostname);
  94. }
  95. string path = m[13];
  96. if (path.empty())
  97. path += '/';
  98. if (string query = m[15]; !query.empty())
  99. path += "?" + query;
  100. mHostname = hostname; // for TLS SNI and Proxy
  101. mService = service; // For proxy
  102. std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
  103. changeState(State::Connecting);
  104. if (config.proxyServer) {
  105. setTcpTransport(std::make_shared<TcpTransport>(
  106. config.proxyServer->hostname, std::to_string(config.proxyServer->port), nullptr));
  107. } else {
  108. setTcpTransport(std::make_shared<TcpTransport>(hostname, service, nullptr));
  109. }
  110. }
  111. void WebSocket::close() {
  112. auto s = state.load();
  113. if (s == State::Connecting || s == State::Open) {
  114. PLOG_VERBOSE << "Closing WebSocket";
  115. changeState(State::Closing);
  116. if (auto transport = std::atomic_load(&mWsTransport))
  117. transport->stop();
  118. else
  119. remoteClose();
  120. }
  121. }
  122. void WebSocket::remoteClose() {
  123. close();
  124. if (state.load() != State::Closed)
  125. closeTransports();
  126. }
  127. bool WebSocket::isOpen() const { return state == State::Open; }
  128. bool WebSocket::isClosed() const { return state == State::Closed; }
  129. size_t WebSocket::maxMessageSize() const {
  130. return config.maxMessageSize.value_or(DEFAULT_WS_MAX_MESSAGE_SIZE);
  131. }
  132. optional<message_variant> WebSocket::receive() {
  133. auto next = mRecvQueue.pop();
  134. return next ? std::make_optional(to_variant(std::move(**next))) : nullopt;
  135. }
  136. optional<message_variant> WebSocket::peek() {
  137. auto next = mRecvQueue.peek();
  138. return next ? std::make_optional(to_variant(std::move(**next))) : nullopt;
  139. }
  140. size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); }
  141. bool WebSocket::changeState(State newState) { return state.exchange(newState) != newState; }
  142. bool WebSocket::outgoing(message_ptr message) {
  143. if (state != State::Open || !mWsTransport)
  144. throw std::runtime_error("WebSocket is not open");
  145. if (message->size() > maxMessageSize())
  146. throw std::runtime_error("Message size exceeds limit");
  147. return mWsTransport->send(message);
  148. }
  149. void WebSocket::incoming(message_ptr message) {
  150. if (!message) {
  151. remoteClose();
  152. return;
  153. }
  154. if (message->type == Message::String || message->type == Message::Binary) {
  155. mRecvQueue.push(message);
  156. triggerAvailable(mRecvQueue.size());
  157. }
  158. }
  159. // Helper for WebSocket::initXTransport methods: start and emplace the transport
  160. template <typename T>
  161. shared_ptr<T> emplaceTransport(WebSocket *ws, shared_ptr<T> *member, shared_ptr<T> transport) {
  162. std::atomic_store(member, transport);
  163. try {
  164. transport->start();
  165. } catch (...) {
  166. std::atomic_store(member, decltype(transport)(nullptr));
  167. transport->stop();
  168. throw;
  169. }
  170. if (ws->state == WebSocket::State::Closed) {
  171. std::atomic_store(member, decltype(transport)(nullptr));
  172. transport->stop();
  173. return nullptr;
  174. }
  175. return transport;
  176. }
  177. shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> transport) {
  178. PLOG_VERBOSE << "Starting TCP transport";
  179. if (!transport)
  180. throw std::logic_error("TCP transport is null");
  181. using State = TcpTransport::State;
  182. try {
  183. if (std::atomic_load(&mTcpTransport))
  184. throw std::logic_error("TCP transport is already set");
  185. transport->onBufferedAmount(weak_bind(&WebSocket::triggerBufferedAmount, this, _1));
  186. transport->onStateChange([this, weak_this = weak_from_this()](State transportState) {
  187. auto shared_this = weak_this.lock();
  188. if (!shared_this)
  189. return;
  190. switch (transportState) {
  191. case State::Connected:
  192. if (config.proxyServer)
  193. initProxyTransport();
  194. else if (mIsSecure)
  195. initTlsTransport();
  196. else
  197. initWsTransport();
  198. break;
  199. case State::Failed:
  200. triggerError("TCP connection failed");
  201. remoteClose();
  202. break;
  203. case State::Disconnected:
  204. if(state == WebSocket::State::Connecting)
  205. remoteClose();
  206. break;
  207. default:
  208. // Ignore
  209. break;
  210. }
  211. });
  212. // WS transport sends a ping on read timeout
  213. auto pingInterval = config.pingInterval.value_or(10000ms);
  214. if (pingInterval > milliseconds::zero())
  215. transport->setReadTimeout(pingInterval);
  216. scheduleConnectionTimeout();
  217. return emplaceTransport(this, &mTcpTransport, std::move(transport));
  218. } catch (const std::exception &e) {
  219. PLOG_ERROR << e.what();
  220. remoteClose();
  221. throw std::runtime_error("TCP transport initialization failed");
  222. }
  223. }
  224. shared_ptr<HttpProxyTransport> WebSocket::initProxyTransport() {
  225. PLOG_VERBOSE << "Starting Tcp Proxy transport";
  226. using State = HttpProxyTransport::State;
  227. try {
  228. if (auto transport = std::atomic_load(&mProxyTransport))
  229. return transport;
  230. auto lower = std::atomic_load(&mTcpTransport);
  231. if (!lower)
  232. throw std::logic_error("No underlying TCP transport for Proxy transport");
  233. auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
  234. auto shared_this = weak_this.lock();
  235. if (!shared_this)
  236. return;
  237. switch (transportState) {
  238. case State::Connected:
  239. if (mIsSecure)
  240. initTlsTransport();
  241. else
  242. initWsTransport();
  243. break;
  244. case State::Failed:
  245. triggerError("Proxy connection failed");
  246. remoteClose();
  247. break;
  248. case State::Disconnected:
  249. if(state == WebSocket::State::Connecting)
  250. remoteClose();
  251. break;
  252. default:
  253. // Ignore
  254. break;
  255. }
  256. };
  257. auto transport = std::make_shared<HttpProxyTransport>(
  258. lower, mHostname.value(), mService.value(), stateChangeCallback);
  259. return emplaceTransport(this, &mProxyTransport, std::move(transport));
  260. } catch (const std::exception &e) {
  261. PLOG_ERROR << e.what();
  262. remoteClose();
  263. throw std::runtime_error("Tcp Proxy transport initialization failed");
  264. }
  265. }
  266. shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
  267. PLOG_VERBOSE << "Starting TLS transport";
  268. using State = TlsTransport::State;
  269. try {
  270. if (auto transport = std::atomic_load(&mTlsTransport))
  271. return transport;
  272. variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>> lower;
  273. if (config.proxyServer) {
  274. auto transport = std::atomic_load(&mProxyTransport);
  275. if (!transport)
  276. throw std::logic_error("No underlying proxy transport for TLS transport");
  277. lower = transport;
  278. } else {
  279. auto transport = std::atomic_load(&mTcpTransport);
  280. if (!transport)
  281. throw std::logic_error("No underlying TCP transport for TLS transport");
  282. lower = transport;
  283. }
  284. auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
  285. auto shared_this = weak_this.lock();
  286. if (!shared_this)
  287. return;
  288. switch (transportState) {
  289. case State::Connected:
  290. initWsTransport();
  291. break;
  292. case State::Failed:
  293. triggerError("TLS connection failed");
  294. remoteClose();
  295. break;
  296. case State::Disconnected:
  297. if(state == WebSocket::State::Connecting)
  298. remoteClose();
  299. break;
  300. default:
  301. // Ignore
  302. break;
  303. }
  304. };
  305. bool verify = mHostname.has_value() && !config.disableTlsVerification;
  306. #ifdef _WIN32
  307. if (std::exchange(verify, false)) {
  308. PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows";
  309. }
  310. #endif
  311. shared_ptr<TlsTransport> transport;
  312. if (verify)
  313. transport = std::make_shared<VerifiedTlsTransport>(lower, mHostname.value(),
  314. mCertificate, stateChangeCallback,
  315. config.caCertificatePemFile);
  316. else
  317. transport =
  318. std::make_shared<TlsTransport>(lower, mHostname, mCertificate, stateChangeCallback);
  319. return emplaceTransport(this, &mTlsTransport, std::move(transport));
  320. } catch (const std::exception &e) {
  321. PLOG_ERROR << e.what();
  322. remoteClose();
  323. throw std::runtime_error("TLS transport initialization failed");
  324. }
  325. }
  326. shared_ptr<WsTransport> WebSocket::initWsTransport() {
  327. PLOG_VERBOSE << "Starting WebSocket transport";
  328. using State = WsTransport::State;
  329. try {
  330. if (auto transport = std::atomic_load(&mWsTransport))
  331. return transport;
  332. variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
  333. lower;
  334. if (mIsSecure) {
  335. auto transport = std::atomic_load(&mTlsTransport);
  336. if (!transport)
  337. throw std::logic_error("No underlying TLS transport for WebSocket transport");
  338. lower = transport;
  339. } else if (config.proxyServer) {
  340. auto transport = std::atomic_load(&mProxyTransport);
  341. if (!transport)
  342. throw std::logic_error("No underlying proxy transport for WebSocket transport");
  343. lower = transport;
  344. } else {
  345. auto transport = std::atomic_load(&mTcpTransport);
  346. if (!transport)
  347. throw std::logic_error("No underlying TCP transport for WebSocket transport");
  348. lower = transport;
  349. }
  350. if (!atomic_load(&mWsHandshake))
  351. atomic_store(&mWsHandshake, std::make_shared<WsHandshake>());
  352. auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
  353. auto shared_this = weak_this.lock();
  354. if (!shared_this)
  355. return;
  356. switch (transportState) {
  357. case State::Connected:
  358. if (state == WebSocket::State::Connecting) {
  359. PLOG_DEBUG << "WebSocket open";
  360. if (changeState(WebSocket::State::Open))
  361. triggerOpen();
  362. }
  363. break;
  364. case State::Failed:
  365. triggerError("WebSocket connection failed");
  366. remoteClose();
  367. break;
  368. case State::Disconnected:
  369. remoteClose();
  370. break;
  371. default:
  372. // Ignore
  373. break;
  374. }
  375. };
  376. auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, config,
  377. weak_bind(&WebSocket::incoming, this, _1),
  378. stateChangeCallback);
  379. return emplaceTransport(this, &mWsTransport, std::move(transport));
  380. } catch (const std::exception &e) {
  381. PLOG_ERROR << e.what();
  382. remoteClose();
  383. throw std::runtime_error("WebSocket transport initialization failed");
  384. }
  385. }
  386. shared_ptr<TcpTransport> WebSocket::getTcpTransport() const {
  387. return std::atomic_load(&mTcpTransport);
  388. }
  389. shared_ptr<TlsTransport> WebSocket::getTlsTransport() const {
  390. return std::atomic_load(&mTlsTransport);
  391. }
  392. shared_ptr<WsTransport> WebSocket::getWsTransport() const {
  393. return std::atomic_load(&mWsTransport);
  394. }
  395. shared_ptr<WsHandshake> WebSocket::getWsHandshake() const {
  396. return std::atomic_load(&mWsHandshake);
  397. }
  398. void WebSocket::closeTransports() {
  399. PLOG_VERBOSE << "Closing transports";
  400. if (!changeState(State::Closed))
  401. return; // already closed
  402. // Pass the pointers to a thread, allowing to terminate a transport from its own thread
  403. auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
  404. auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
  405. auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
  406. if (ws)
  407. ws->onRecv(nullptr);
  408. if (tcp)
  409. tcp->onBufferedAmount(nullptr);
  410. using array = std::array<shared_ptr<Transport>, 3>;
  411. array transports{std::move(ws), std::move(tls), std::move(tcp)};
  412. for (const auto &t : transports)
  413. if (t)
  414. t->onStateChange(nullptr);
  415. TearDownProcessor::Instance().enqueue(
  416. [transports = std::move(transports), token = Init::Instance().token()]() mutable {
  417. for (const auto &t : transports) {
  418. if (t) {
  419. t->stop();
  420. break;
  421. }
  422. }
  423. for (auto &t : transports)
  424. t.reset();
  425. });
  426. triggerClosed();
  427. }
  428. void WebSocket::scheduleConnectionTimeout() {
  429. auto defaultTimeout = 30s;
  430. auto timeout = config.connectionTimeout.value_or(milliseconds(defaultTimeout));
  431. if (timeout > milliseconds::zero()) {
  432. ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() {
  433. if (auto locked = weak_this.lock()) {
  434. if (locked->state == WebSocket::State::Connecting) {
  435. PLOG_WARNING << "WebSocket connection timed out";
  436. locked->triggerError("Connection timed out");
  437. locked->remoteClose();
  438. }
  439. }
  440. });
  441. }
  442. }
  443. } // namespace rtc::impl
  444. #endif