wstransport.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. /**
  2. * Copyright (c) 2020 Paul-Louis Ageneau
  3. *
  4. * This library is free software; you can redistribute it and/or
  5. * modify it under the terms of the GNU Lesser General Public
  6. * License as published by the Free Software Foundation; either
  7. * version 2.1 of the License, or (at your option) any later version.
  8. *
  9. * This library is distributed in the hope that it will be useful,
  10. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. * Lesser General Public License for more details.
  13. *
  14. * You should have received a copy of the GNU Lesser General Public
  15. * License along with this library; if not, write to the Free Software
  16. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  17. */
  18. #include "wstransport.hpp"
  19. #include "tcptransport.hpp"
  20. #include "tlstransport.hpp"
  21. #include "base64.hpp"
  22. #if RTC_ENABLE_WEBSOCKET
  23. #include <chrono>
  24. #include <list>
  25. #include <map>
  26. #include <random>
  27. #include <regex>
  28. #ifdef _WIN32
  29. #include <winsock2.h>
  30. #else
  31. #include <arpa/inet.h>
  32. #endif
  33. #ifndef htonll
  34. #define htonll(x) \
  35. ((uint64_t)htonl(((uint64_t)(x)&0xFFFFFFFF) << 32) | (uint64_t)htonl((uint64_t)(x) >> 32))
  36. #endif
  37. #ifndef ntohll
  38. #define ntohll(x) htonll(x)
  39. #endif
  40. namespace rtc {
  41. using namespace std::chrono;
  42. using std::to_integer;
  43. using std::to_string;
  44. using random_bytes_engine =
  45. std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned short>;
  46. WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string path,
  47. message_callback recvCallback, state_callback stateCallback)
  48. : Transport(lower, std::move(stateCallback)), mHost(std::move(host)), mPath(std::move(path)) {
  49. onRecv(recvCallback);
  50. PLOG_DEBUG << "Initializing WebSocket transport";
  51. registerIncoming();
  52. sendHttpRequest();
  53. }
  54. WsTransport::~WsTransport() { stop(); }
  55. bool WsTransport::stop() {
  56. if (!Transport::stop())
  57. return false;
  58. close();
  59. return true;
  60. }
  61. bool WsTransport::send(message_ptr message) {
  62. if (!message || state() != State::Connected)
  63. return false;
  64. PLOG_VERBOSE << "Send size=" << message->size();
  65. return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(),
  66. message->size(), true, true});
  67. }
  68. void WsTransport::incoming(message_ptr message) {
  69. auto s = state();
  70. if (s != State::Connecting && s != State::Connected)
  71. return; // Drop
  72. if (message) {
  73. PLOG_VERBOSE << "Incoming size=" << message->size();
  74. if (message->size() == 0) {
  75. // TCP is idle, send a ping
  76. PLOG_DEBUG << "WebSocket sending ping";
  77. uint32_t dummy = 0;
  78. sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, true});
  79. return;
  80. }
  81. mBuffer.insert(mBuffer.end(), message->begin(), message->end());
  82. try {
  83. if (state() == State::Connecting) {
  84. if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) {
  85. PLOG_INFO << "WebSocket open";
  86. changeState(State::Connected);
  87. mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
  88. }
  89. }
  90. if (state() == State::Connected) {
  91. Frame frame;
  92. while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
  93. recvFrame(frame);
  94. mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
  95. }
  96. }
  97. return;
  98. } catch (const std::exception &e) {
  99. PLOG_ERROR << e.what();
  100. }
  101. }
  102. if (state() == State::Connected) {
  103. PLOG_INFO << "WebSocket disconnected";
  104. changeState(State::Disconnected);
  105. recv(nullptr);
  106. } else {
  107. PLOG_ERROR << "WebSocket handshake failed";
  108. changeState(State::Failed);
  109. }
  110. }
  111. void WsTransport::close() {
  112. if (state() == State::Connected) {
  113. sendFrame({CLOSE, NULL, 0, true, true});
  114. PLOG_INFO << "WebSocket closing";
  115. changeState(State::Disconnected);
  116. }
  117. }
  118. bool WsTransport::sendHttpRequest() {
  119. changeState(State::Connecting);
  120. auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
  121. random_bytes_engine generator(seed);
  122. binary key(16);
  123. auto k = reinterpret_cast<uint8_t *>(key.data());
  124. std::generate(k, k + key.size(), [&]() { return uint8_t(generator()); });
  125. const string request = "GET " + mPath +
  126. " HTTP/1.1\r\n"
  127. "Host: " +
  128. mHost +
  129. "\r\n"
  130. "Connection: Upgrade\r\n"
  131. "Upgrade: websocket\r\n"
  132. "Sec-WebSocket-Version: 13\r\n"
  133. "Sec-WebSocket-Key: " +
  134. to_base64(key) +
  135. "\r\n"
  136. "\r\n";
  137. auto data = reinterpret_cast<const byte *>(request.data());
  138. auto size = request.size();
  139. return outgoing(make_message(data, data + size));
  140. }
  141. size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) {
  142. std::list<string> lines;
  143. auto begin = reinterpret_cast<const char *>(buffer);
  144. auto end = begin + size;
  145. auto cur = begin;
  146. while (true) {
  147. auto last = cur;
  148. cur = std::find(cur, end, '\n');
  149. if (cur == end)
  150. return 0;
  151. string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
  152. if (line.empty())
  153. break;
  154. lines.emplace_back(std::move(line));
  155. }
  156. size_t length = cur - begin;
  157. if (lines.empty())
  158. throw std::runtime_error("Invalid HTTP response for WebSocket");
  159. string status = std::move(lines.front());
  160. lines.pop_front();
  161. std::istringstream ss(status);
  162. string protocol;
  163. unsigned int code = 0;
  164. ss >> protocol >> code;
  165. PLOG_DEBUG << "WebSocket response code: " << code;
  166. if (code != 101)
  167. throw std::runtime_error("Unexpected response code for WebSocket: " + to_string(code));
  168. std::multimap<string, string> headers;
  169. for (const auto &line : lines) {
  170. if (size_t pos = line.find_first_of(':'); pos != string::npos) {
  171. string key = line.substr(0, pos);
  172. string value = line.substr(line.find_first_not_of(' ', pos + 1));
  173. std::transform(key.begin(), key.end(), key.begin(),
  174. [](char c) { return std::tolower(c); });
  175. headers.emplace(std::move(key), std::move(value));
  176. } else {
  177. headers.emplace(line, "");
  178. }
  179. }
  180. auto h = headers.find("upgrade");
  181. if (h == headers.end() || h->second != "websocket")
  182. throw std::runtime_error("WebSocket update header missing or mismatching");
  183. h = headers.find("sec-websocket-accept");
  184. if (h == headers.end())
  185. throw std::runtime_error("WebSocket accept header missing");
  186. // TODO: Verify Sec-WebSocket-Accept
  187. return length;
  188. }
// http://tools.ietf.org/html/rfc6455#section-5.2 Base Framing Protocol
//
//  0                   1                   2                   3
//  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-------+-+-------------+-------------------------------+
// |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
// |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
// |N|V|V|V|       |S|             |   (if payload len==126/127)   |
// | |1|2|3|       |K|             |                               |
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
// |     Extended payload length continued, if payload len == 127  |
// + - - - - - - - - - - - - - - - +-------------------------------+
// |                               |Masking-key, if MASK set to 1  |
// +-------------------------------+-------------------------------+
// | Masking-key (continued)       |          Payload Data         |
// +-------------------------------- - - - - - - - - - - - - - - - +
// :                     Payload Data continued ...                :
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
// |                     Payload Data continued ...                |
// +---------------------------------------------------------------+
  209. size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
  210. const byte *end = buffer + size;
  211. if (end - buffer < 2)
  212. return 0;
  213. byte *cur = buffer;
  214. auto b1 = to_integer<uint8_t>(*cur++);
  215. auto b2 = to_integer<uint8_t>(*cur++);
  216. frame.fin = (b1 & 0x80) != 0;
  217. frame.mask = (b2 & 0x80) != 0;
  218. frame.opcode = static_cast<Opcode>(b1 & 0x0F);
  219. frame.length = b2 & 0x7F;
  220. if (frame.length == 0x7E) {
  221. if (end - cur < 2)
  222. return 0;
  223. frame.length = ntohs(*reinterpret_cast<const uint16_t *>(cur));
  224. cur += 2;
  225. } else if (frame.length == 0x7F) {
  226. if (end - cur < 8)
  227. return 0;
  228. frame.length = ntohll(*reinterpret_cast<const uint64_t *>(cur));
  229. cur += 8;
  230. }
  231. const byte *maskingKey = nullptr;
  232. if (frame.mask) {
  233. if (end - cur < 4)
  234. return 0;
  235. maskingKey = cur;
  236. cur += 4;
  237. }
  238. if (size_t(end - cur) < frame.length)
  239. return 0;
  240. frame.payload = cur;
  241. if (maskingKey)
  242. for (size_t i = 0; i < frame.length; ++i)
  243. frame.payload[i] ^= maskingKey[i % 4];
  244. cur += frame.length;
  245. return size_t(cur - buffer);
  246. }
  247. void WsTransport::recvFrame(const Frame &frame) {
  248. PLOG_DEBUG << "WebSocket received frame: opcode=" << int(frame.opcode)
  249. << ", length=" << frame.length;
  250. switch (frame.opcode) {
  251. case TEXT_FRAME:
  252. case BINARY_FRAME: {
  253. if (!mPartial.empty()) {
  254. PLOG_WARNING << "WebSocket unfinished message: type="
  255. << (mPartialOpcode == TEXT_FRAME ? "text" : "binary")
  256. << ", length=" << mPartial.size();
  257. auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
  258. recv(make_message(mPartial.begin(), mPartial.end(), type));
  259. mPartial.clear();
  260. }
  261. mPartialOpcode = frame.opcode;
  262. if (frame.fin) {
  263. PLOG_DEBUG << "WebSocket finished message: type="
  264. << (frame.opcode == TEXT_FRAME ? "text" : "binary")
  265. << ", length=" << frame.length;
  266. auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
  267. recv(make_message(frame.payload, frame.payload + frame.length, type));
  268. } else {
  269. mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
  270. }
  271. break;
  272. }
  273. case CONTINUATION: {
  274. mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
  275. if (frame.fin) {
  276. PLOG_DEBUG << "WebSocket finished message: type="
  277. << (frame.opcode == TEXT_FRAME ? "text" : "binary")
  278. << ", length=" << mPartial.size();
  279. auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
  280. recv(make_message(mPartial.begin(), mPartial.end(), type));
  281. mPartial.clear();
  282. }
  283. break;
  284. }
  285. case PING: {
  286. PLOG_DEBUG << "WebSocket received ping, sending pong";
  287. sendFrame({PONG, frame.payload, frame.length, true, true});
  288. break;
  289. }
  290. case PONG: {
  291. PLOG_DEBUG << "WebSocket received pong";
  292. break;
  293. }
  294. case CLOSE: {
  295. close();
  296. PLOG_INFO << "WebSocket closed";
  297. changeState(State::Disconnected);
  298. break;
  299. }
  300. default: {
  301. close();
  302. throw std::invalid_argument("Unknown WebSocket opcode: " + to_string(frame.opcode));
  303. }
  304. }
  305. }
  306. bool WsTransport::sendFrame(const Frame &frame) {
  307. PLOG_DEBUG << "WebSocket sending frame: opcode=" << int(frame.opcode)
  308. << ", length=" << frame.length;
  309. byte buffer[14];
  310. byte *cur = buffer;
  311. *cur++ = byte((frame.opcode & 0x0F) | (frame.fin ? 0x80 : 0));
  312. if (frame.length < 0x7E) {
  313. *cur++ = byte((frame.length & 0x7F) | (frame.mask ? 0x80 : 0));
  314. } else if (frame.length <= 0xFFFF) {
  315. *cur++ = byte(0x7E | (frame.mask ? 0x80 : 0));
  316. *reinterpret_cast<uint16_t *>(cur) = htons(uint16_t(frame.length));
  317. cur += 2;
  318. } else {
  319. *cur++ = byte(0x7F | (frame.mask ? 0x80 : 0));
  320. *reinterpret_cast<uint64_t *>(cur) = htonll(uint64_t(frame.length));
  321. cur += 8;
  322. }
  323. if (frame.mask) {
  324. auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
  325. random_bytes_engine generator(seed);
  326. byte *maskingKey = reinterpret_cast<byte *>(cur);
  327. auto u = reinterpret_cast<uint8_t *>(maskingKey);
  328. std::generate(u, u + 4, [&]() { return uint8_t(generator()); });
  329. cur += 4;
  330. for (size_t i = 0; i < frame.length; ++i)
  331. frame.payload[i] ^= maskingKey[i % 4];
  332. }
  333. outgoing(make_message(buffer, cur)); // header
  334. return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload
  335. }
  336. } // namespace rtc
  337. #endif