websocket.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. /**
  2. * Copyright (c) 2020-2021 Paul-Louis Ageneau
  3. *
  4. * This library is free software; you can redistribute it and/or
  5. * modify it under the terms of the GNU Lesser General Public
  6. * License as published by the Free Software Foundation; either
  7. * version 2.1 of the License, or (at your option) any later version.
  8. *
  9. * This library is distributed in the hope that it will be useful,
  10. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. * Lesser General Public License for more details.
  13. *
  14. * You should have received a copy of the GNU Lesser General Public
  15. * License along with this library; if not, write to the Free Software
  16. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  17. */
  18. #if RTC_ENABLE_WEBSOCKET
  19. #include "websocket.hpp"
  20. #include "common.hpp"
  21. #include "internals.hpp"
  22. #include "processor.hpp"
  23. #include "utils.hpp"
  24. #include "tcptransport.hpp"
  25. #include "tlstransport.hpp"
  26. #include "verifiedtlstransport.hpp"
  27. #include "wstransport.hpp"
  28. #include <array>
  29. #include <chrono>
  30. #include <regex>
  31. #ifdef _WIN32
  32. #include <winsock2.h>
  33. #endif
  34. namespace rtc::impl {
  35. using namespace std::placeholders;
  36. using namespace std::chrono_literals;
  37. WebSocket::WebSocket(optional<Configuration> optConfig, certificate_ptr certificate)
  38. : config(optConfig ? std::move(*optConfig) : Configuration()),
  39. mCertificate(std::move(certificate)), mIsSecure(mCertificate != nullptr),
  40. mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
  41. PLOG_VERBOSE << "Creating WebSocket";
  42. }
  43. WebSocket::~WebSocket() { PLOG_VERBOSE << "Destroying WebSocket"; }
  44. void WebSocket::open(const string &url) {
  45. PLOG_VERBOSE << "Opening WebSocket to URL: " << url;
  46. if (state != State::Closed)
  47. throw std::logic_error("WebSocket must be closed before opening");
  48. if (config.proxyServer) {
  49. PLOG_WARNING << "Proxy server support for WebSocket is not implemented";
  50. }
  51. // Modified regex from RFC 3986, see https://www.rfc-editor.org/rfc/rfc3986.html#appendix-B
  52. static const char *rs =
  53. R"(^(([^:.@/?#]+):)?(/{0,2}((([^:@]*)(:([^@]*))?)@)?(([^:/?#]*)(:([^/?#]*))?))?([^?#]*)(\?([^#]*))?(#(.*))?)";
  54. static const std::regex r(rs, std::regex::extended);
  55. std::smatch m;
  56. if (!std::regex_match(url, m, r) || m[10].length() == 0)
  57. throw std::invalid_argument("Invalid WebSocket URL: " + url);
  58. string scheme = m[2];
  59. if (scheme.empty())
  60. scheme = "ws";
  61. if (scheme != "ws" && scheme != "wss")
  62. throw std::invalid_argument("Invalid WebSocket scheme: " + scheme);
  63. mIsSecure = (scheme != "ws");
  64. string username = utils::url_decode(m[6]);
  65. string password = utils::url_decode(m[8]);
  66. if (!username.empty() || !password.empty()) {
  67. PLOG_WARNING << "HTTP authentication support for WebSocket is not implemented";
  68. }
  69. string host;
  70. string hostname = m[10];
  71. string service = m[12];
  72. if (service.empty()) {
  73. service = mIsSecure ? "443" : "80";
  74. host = hostname;
  75. } else {
  76. host = hostname + ':' + service;
  77. }
  78. if (hostname.front() == '[' && hostname.back() == ']') {
  79. // IPv6 literal
  80. hostname.erase(hostname.begin());
  81. hostname.pop_back();
  82. } else {
  83. hostname = utils::url_decode(hostname);
  84. }
  85. string path = m[13];
  86. if (path.empty())
  87. path += '/';
  88. if (string query = m[15]; !query.empty())
  89. path += "?" + query;
  90. mHostname = hostname; // for TLS SNI
  91. std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
  92. changeState(State::Connecting);
  93. setTcpTransport(std::make_shared<TcpTransport>(hostname, service, nullptr));
  94. }
  95. void WebSocket::close() {
  96. auto s = state.load();
  97. if (s == State::Connecting || s == State::Open) {
  98. PLOG_VERBOSE << "Closing WebSocket";
  99. changeState(State::Closing);
  100. if (auto transport = std::atomic_load(&mWsTransport))
  101. transport->close();
  102. else
  103. remoteClose();
  104. }
  105. }
  106. void WebSocket::remoteClose() {
  107. if (state != State::Closed) {
  108. close();
  109. closeTransports();
  110. }
  111. }
  112. bool WebSocket::isOpen() const { return state == State::Open; }
  113. bool WebSocket::isClosed() const { return state == State::Closed; }
  114. size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
  115. optional<message_variant> WebSocket::receive() {
  116. while (auto next = mRecvQueue.tryPop()) {
  117. message_ptr message = *next;
  118. if (message->type != Message::Control)
  119. return to_variant(std::move(*message));
  120. }
  121. return nullopt;
  122. }
  123. optional<message_variant> WebSocket::peek() {
  124. while (auto next = mRecvQueue.peek()) {
  125. message_ptr message = *next;
  126. if (message->type != Message::Control)
  127. return to_variant(std::move(*message));
  128. mRecvQueue.tryPop();
  129. }
  130. return nullopt;
  131. }
  132. size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); }
  133. bool WebSocket::changeState(State newState) { return state.exchange(newState) != newState; }
  134. bool WebSocket::outgoing(message_ptr message) {
  135. if (state != State::Open || !mWsTransport)
  136. throw std::runtime_error("WebSocket is not open");
  137. if (message->size() > maxMessageSize())
  138. throw std::runtime_error("Message size exceeds limit");
  139. return mWsTransport->send(message);
  140. }
  141. void WebSocket::incoming(message_ptr message) {
  142. if (!message) {
  143. remoteClose();
  144. return;
  145. }
  146. if (message->type == Message::String || message->type == Message::Binary) {
  147. mRecvQueue.push(message);
  148. triggerAvailable(mRecvQueue.size());
  149. }
  150. }
  151. // Helper for WebSocket::initXTransport methods: start and emplace the transport
  152. template <typename T>
  153. shared_ptr<T> emplaceTransport(WebSocket *ws, shared_ptr<T> *member, shared_ptr<T> transport) {
  154. std::atomic_store(member, transport);
  155. try {
  156. transport->start();
  157. } catch (...) {
  158. std::atomic_store(member, decltype(transport)(nullptr));
  159. transport->stop();
  160. throw;
  161. }
  162. if (ws->state == WebSocket::State::Closed) {
  163. std::atomic_store(member, decltype(transport)(nullptr));
  164. transport->stop();
  165. return nullptr;
  166. }
  167. return transport;
  168. }
  169. shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> transport) {
  170. PLOG_VERBOSE << "Starting TCP transport";
  171. if (!transport)
  172. throw std::logic_error("TCP transport is null");
  173. using State = TcpTransport::State;
  174. try {
  175. if (std::atomic_load(&mTcpTransport))
  176. throw std::logic_error("TCP transport is already set");
  177. transport->onStateChange([this, weak_this = weak_from_this()](State transportState) {
  178. auto shared_this = weak_this.lock();
  179. if (!shared_this)
  180. return;
  181. switch (transportState) {
  182. case State::Connected:
  183. if (mIsSecure)
  184. initTlsTransport();
  185. else
  186. initWsTransport();
  187. break;
  188. case State::Failed:
  189. triggerError("TCP connection failed");
  190. remoteClose();
  191. break;
  192. case State::Disconnected:
  193. remoteClose();
  194. break;
  195. default:
  196. // Ignore
  197. break;
  198. }
  199. });
  200. // WS transport sends a ping on read timeout
  201. auto pingInterval = config.pingInterval.value_or(10000ms);
  202. if (pingInterval > std::chrono::milliseconds::zero())
  203. transport->setReadTimeout(pingInterval);
  204. return emplaceTransport(this, &mTcpTransport, std::move(transport));
  205. } catch (const std::exception &e) {
  206. PLOG_ERROR << e.what();
  207. remoteClose();
  208. throw std::runtime_error("TCP transport initialization failed");
  209. }
  210. }
  211. shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
  212. PLOG_VERBOSE << "Starting TLS transport";
  213. using State = TlsTransport::State;
  214. try {
  215. if (auto transport = std::atomic_load(&mTlsTransport))
  216. return transport;
  217. auto lower = std::atomic_load(&mTcpTransport);
  218. if (!lower)
  219. throw std::logic_error("No underlying TCP transport for TLS transport");
  220. auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
  221. auto shared_this = weak_this.lock();
  222. if (!shared_this)
  223. return;
  224. switch (transportState) {
  225. case State::Connected:
  226. initWsTransport();
  227. break;
  228. case State::Failed:
  229. triggerError("TLS connection failed");
  230. remoteClose();
  231. break;
  232. case State::Disconnected:
  233. remoteClose();
  234. break;
  235. default:
  236. // Ignore
  237. break;
  238. }
  239. };
  240. bool verify = mHostname.has_value() && !config.disableTlsVerification;
  241. #ifdef _WIN32
  242. if (std::exchange(verify, false)) {
  243. PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows";
  244. }
  245. #endif
  246. shared_ptr<TlsTransport> transport;
  247. if (verify)
  248. transport = std::make_shared<VerifiedTlsTransport>(lower, mHostname.value(),
  249. mCertificate, stateChangeCallback);
  250. else
  251. transport =
  252. std::make_shared<TlsTransport>(lower, mHostname, mCertificate, stateChangeCallback);
  253. return emplaceTransport(this, &mTlsTransport, std::move(transport));
  254. } catch (const std::exception &e) {
  255. PLOG_ERROR << e.what();
  256. remoteClose();
  257. throw std::runtime_error("TLS transport initialization failed");
  258. }
  259. }
  260. shared_ptr<WsTransport> WebSocket::initWsTransport() {
  261. PLOG_VERBOSE << "Starting WebSocket transport";
  262. using State = WsTransport::State;
  263. try {
  264. if (auto transport = std::atomic_load(&mWsTransport))
  265. return transport;
  266. variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower;
  267. if (mIsSecure) {
  268. auto transport = std::atomic_load(&mTlsTransport);
  269. if (!transport)
  270. throw std::logic_error("No underlying TLS transport for WebSocket transport");
  271. lower = transport;
  272. } else {
  273. auto transport = std::atomic_load(&mTcpTransport);
  274. if (!transport)
  275. throw std::logic_error("No underlying TCP transport for WebSocket transport");
  276. lower = transport;
  277. }
  278. if (!atomic_load(&mWsHandshake))
  279. atomic_store(&mWsHandshake, std::make_shared<WsHandshake>());
  280. auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
  281. auto shared_this = weak_this.lock();
  282. if (!shared_this)
  283. return;
  284. switch (transportState) {
  285. case State::Connected:
  286. if (state == WebSocket::State::Connecting) {
  287. PLOG_DEBUG << "WebSocket open";
  288. if (changeState(WebSocket::State::Open))
  289. triggerOpen();
  290. }
  291. break;
  292. case State::Failed:
  293. triggerError("WebSocket connection failed");
  294. remoteClose();
  295. break;
  296. case State::Disconnected:
  297. remoteClose();
  298. break;
  299. default:
  300. // Ignore
  301. break;
  302. }
  303. };
  304. auto maxOutstandingPings = config.maxOutstandingPings.value_or(0);
  305. auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, maxOutstandingPings,
  306. weak_bind(&WebSocket::incoming, this, _1),
  307. stateChangeCallback);
  308. return emplaceTransport(this, &mWsTransport, std::move(transport));
  309. } catch (const std::exception &e) {
  310. PLOG_ERROR << e.what();
  311. remoteClose();
  312. throw std::runtime_error("WebSocket transport initialization failed");
  313. }
  314. }
  315. shared_ptr<TcpTransport> WebSocket::getTcpTransport() const {
  316. return std::atomic_load(&mTcpTransport);
  317. }
  318. shared_ptr<TlsTransport> WebSocket::getTlsTransport() const {
  319. return std::atomic_load(&mTlsTransport);
  320. }
  321. shared_ptr<WsTransport> WebSocket::getWsTransport() const {
  322. return std::atomic_load(&mWsTransport);
  323. }
  324. shared_ptr<WsHandshake> WebSocket::getWsHandshake() const {
  325. return std::atomic_load(&mWsHandshake);
  326. }
  327. void WebSocket::closeTransports() {
  328. PLOG_VERBOSE << "Closing transports";
  329. if (!changeState(State::Closed))
  330. return; // already closed
  331. // Pass the pointers to a thread, allowing to terminate a transport from its own thread
  332. auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
  333. auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
  334. auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
  335. if (ws)
  336. ws->onRecv(nullptr);
  337. using array = std::array<shared_ptr<Transport>, 3>;
  338. array transports{std::move(ws), std::move(tls), std::move(tcp)};
  339. for (const auto &t : transports)
  340. if (t)
  341. t->onStateChange(nullptr);
  342. TearDownProcessor::Instance().enqueue(
  343. [transports = std::move(transports), token = Init::Instance().token()]() mutable {
  344. for (const auto &t : transports)
  345. if (t)
  346. t->stop();
  347. for (auto &t : transports)
  348. t.reset();
  349. });
  350. triggerClosed();
  351. }
  352. } // namespace rtc::impl
  353. #endif