tcptransport.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. /**
  2. * Copyright (c) 2020 Paul-Louis Ageneau
  3. *
  4. * This Source Code Form is subject to the terms of the Mozilla Public
  5. * License, v. 2.0. If a copy of the MPL was not distributed with this
  6. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  7. */
  8. #include "tcptransport.hpp"
  9. #include "internals.hpp"
  10. #include "threadpool.hpp"
  11. #if RTC_ENABLE_WEBSOCKET
  12. #ifndef _WIN32
  13. #include <fcntl.h>
  14. #include <unistd.h>
  15. #endif
  16. #include <chrono>
  17. namespace rtc::impl {
  18. using namespace std::placeholders;
  19. using namespace std::chrono_literals;
  20. using std::chrono::duration_cast;
  21. using std::chrono::milliseconds;
  22. namespace {
  23. bool unmap_inet6_v4mapped(struct sockaddr *sa, socklen_t *len) {
  24. if (sa->sa_family != AF_INET6)
  25. return false;
  26. const auto *sin6 = reinterpret_cast<struct sockaddr_in6 *>(sa);
  27. if (!IN6_IS_ADDR_V4MAPPED(&sin6->sin6_addr))
  28. return false;
  29. struct sockaddr_in6 copy = *sin6;
  30. sin6 = &copy;
  31. auto *sin = reinterpret_cast<struct sockaddr_in *>(sa);
  32. std::memset(sin, 0, sizeof(*sin));
  33. sin->sin_family = AF_INET;
  34. sin->sin_port = sin6->sin6_port;
  35. std::memcpy(&sin->sin_addr, reinterpret_cast<const uint8_t *>(&sin6->sin6_addr) + 12, 4);
  36. *len = sizeof(*sin);
  37. return true;
  38. }
  39. }
  40. TcpTransport::TcpTransport(string hostname, string service, state_callback callback)
  41. : Transport(nullptr, std::move(callback)), mIsActive(true), mHostname(std::move(hostname)),
  42. mService(std::move(service)), mSock(INVALID_SOCKET) {
  43. PLOG_DEBUG << "Initializing TCP transport";
  44. }
  45. TcpTransport::TcpTransport(socket_t sock, state_callback callback)
  46. : Transport(nullptr, std::move(callback)), mIsActive(false), mSock(sock) {
  47. PLOG_DEBUG << "Initializing TCP transport with socket";
  48. // Configure socket
  49. configureSocket();
  50. // Retrieve hostname and service
  51. struct sockaddr_storage addr;
  52. socklen_t addrlen = sizeof(addr);
  53. if (::getpeername(mSock, reinterpret_cast<struct sockaddr *>(&addr), &addrlen) < 0)
  54. throw std::runtime_error("getpeername failed");
  55. unmap_inet6_v4mapped(reinterpret_cast<struct sockaddr *>(&addr), &addrlen);
  56. char node[MAX_NUMERICNODE_LEN];
  57. char serv[MAX_NUMERICSERV_LEN];
  58. if (::getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), addrlen, node,
  59. MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
  60. NI_NUMERICHOST | NI_NUMERICSERV) != 0)
  61. throw std::runtime_error("getnameinfo failed");
  62. mHostname = node;
  63. mService = serv;
  64. }
  65. TcpTransport::~TcpTransport() { close(); }
  66. void TcpTransport::onBufferedAmount(amount_callback callback) {
  67. mBufferedAmountCallback = std::move(callback);
  68. }
  69. void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) {
  70. mReadTimeout = readTimeout;
  71. }
  72. void TcpTransport::start() {
  73. if (mSock == INVALID_SOCKET) {
  74. connect();
  75. } else {
  76. changeState(State::Connected);
  77. setPoll(PollService::Direction::In);
  78. }
  79. }
  80. bool TcpTransport::send(message_ptr message) {
  81. std::lock_guard lock(mSendMutex);
  82. if (state() != State::Connected)
  83. throw std::runtime_error("Connection is not open");
  84. if (!message || message->size() == 0)
  85. return trySendQueue();
  86. PLOG_VERBOSE << "Send size=" << message->size();
  87. return outgoing(message);
  88. }
  89. void TcpTransport::incoming(message_ptr message) {
  90. if (!message)
  91. return;
  92. PLOG_VERBOSE << "Incoming size=" << message->size();
  93. recv(message);
  94. }
  95. bool TcpTransport::outgoing(message_ptr message) {
  96. // mSendMutex must be locked
  97. // Flush the queue, and if nothing is pending, try to send directly
  98. if (trySendQueue() && trySendMessage(message))
  99. return true;
  100. mSendQueue.push(message);
  101. updateBufferedAmount(ptrdiff_t(message->size()));
  102. setPoll(PollService::Direction::Both);
  103. return false;
  104. }
  105. bool TcpTransport::isActive() const { return mIsActive; }
  106. string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
  107. void TcpTransport::connect() {
  108. if (state() == State::Connecting)
  109. throw std::logic_error("TCP connection is already in progress");
  110. if (state() == State::Connected)
  111. throw std::logic_error("TCP is already connected");
  112. PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService;
  113. changeState(State::Connecting);
  114. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::resolve, this));
  115. }
  116. void TcpTransport::resolve() {
  117. std::lock_guard lock(mSendMutex);
  118. mResolved.clear();
  119. if (state() != State::Connecting)
  120. return; // Cancelled
  121. try {
  122. PLOG_DEBUG << "Resolving " << mHostname << ":" << mService;
  123. struct addrinfo hints = {};
  124. hints.ai_family = AF_UNSPEC;
  125. hints.ai_socktype = SOCK_STREAM;
  126. hints.ai_protocol = IPPROTO_TCP;
  127. hints.ai_flags = AI_ADDRCONFIG;
  128. struct addrinfo *result = nullptr;
  129. if (getaddrinfo(mHostname.c_str(), mService.c_str(), &hints, &result))
  130. throw std::runtime_error("Resolution failed for \"" + mHostname + ":" + mService +
  131. "\"");
  132. try {
  133. struct addrinfo *ai = result;
  134. while (ai) {
  135. struct sockaddr_storage addr;
  136. std::memcpy(&addr, ai->ai_addr, ai->ai_addrlen);
  137. mResolved.emplace_back(addr, socklen_t(ai->ai_addrlen));
  138. ai = ai->ai_next;
  139. }
  140. } catch (...) {
  141. freeaddrinfo(result);
  142. throw;
  143. }
  144. freeaddrinfo(result);
  145. } catch (const std::exception &e) {
  146. PLOG_WARNING << e.what();
  147. changeState(State::Failed);
  148. return;
  149. }
  150. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  151. }
  152. void TcpTransport::attempt() {
  153. try {
  154. std::lock_guard lock(mSendMutex);
  155. if (state() != State::Connecting)
  156. return; // Cancelled
  157. if (mSock != INVALID_SOCKET) {
  158. ::closesocket(mSock);
  159. mSock = INVALID_SOCKET;
  160. }
  161. if (mResolved.empty()) {
  162. PLOG_WARNING << "Connection to " << mHostname << ":" << mService << " failed";
  163. changeState(State::Failed);
  164. return;
  165. }
  166. auto [addr, addrlen] = mResolved.front();
  167. mResolved.pop_front();
  168. createSocket(reinterpret_cast<const struct sockaddr *>(&addr), addrlen);
  169. } catch (const std::runtime_error &e) {
  170. PLOG_DEBUG << e.what();
  171. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  172. return;
  173. }
  174. const auto timeout = 10s;
  175. PollService::Instance().add(mSock, {PollService::Direction::Out, timeout,
  176. weak_bind(&TcpTransport::processConnect, this, _1)});
  177. }
  178. void TcpTransport::createSocket(const struct sockaddr *addr, socklen_t addrlen) {
  179. try {
  180. char node[MAX_NUMERICNODE_LEN];
  181. char serv[MAX_NUMERICSERV_LEN];
  182. if (getnameinfo(addr, addrlen, node, MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
  183. NI_NUMERICHOST | NI_NUMERICSERV) == 0) {
  184. PLOG_DEBUG << "Trying address " << node << ":" << serv;
  185. }
  186. PLOG_VERBOSE << "Creating TCP socket";
  187. // Create socket
  188. mSock = ::socket(addr->sa_family, SOCK_STREAM, IPPROTO_TCP);
  189. if (mSock == INVALID_SOCKET)
  190. throw std::runtime_error("TCP socket creation failed");
  191. // Configure socket
  192. configureSocket();
  193. // Initiate connection
  194. int ret = ::connect(mSock, addr, addrlen);
  195. if (ret < 0 && sockerrno != SEINPROGRESS && sockerrno != SEWOULDBLOCK) {
  196. std::ostringstream msg;
  197. msg << "TCP connection to " << node << ":" << serv << " failed, errno=" << sockerrno;
  198. throw std::runtime_error(msg.str());
  199. }
  200. } catch (...) {
  201. if (mSock != INVALID_SOCKET) {
  202. ::closesocket(mSock);
  203. mSock = INVALID_SOCKET;
  204. }
  205. throw;
  206. }
  207. }
  208. void TcpTransport::configureSocket() {
  209. // Set non-blocking
  210. ctl_t nbio = 1;
  211. if (::ioctlsocket(mSock, FIONBIO, &nbio) < 0)
  212. throw std::runtime_error("Failed to set socket non-blocking mode");
  213. // Disable the Nagle algorithm
  214. int nodelay = 1;
  215. ::setsockopt(mSock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&nodelay),
  216. sizeof(nodelay));
  217. #ifdef __APPLE__
  218. // MacOS lacks MSG_NOSIGNAL and requires SO_NOSIGPIPE instead
  219. const sockopt_t enabled = 1;
  220. if (::setsockopt(mSock, SOL_SOCKET, SO_NOSIGPIPE, &enabled, sizeof(enabled)) < 0)
  221. throw std::runtime_error("Failed to disable SIGPIPE for socket");
  222. #endif
  223. }
  224. void TcpTransport::setPoll(PollService::Direction direction) {
  225. PollService::Instance().add(
  226. mSock, {direction, direction == PollService::Direction::In ? mReadTimeout : nullopt,
  227. weak_bind(&TcpTransport::process, this, _1)});
  228. }
  229. void TcpTransport::close() {
  230. std::lock_guard lock(mSendMutex);
  231. if (mSock != INVALID_SOCKET) {
  232. PLOG_DEBUG << "Closing TCP socket";
  233. PollService::Instance().remove(mSock);
  234. ::closesocket(mSock);
  235. mSock = INVALID_SOCKET;
  236. }
  237. changeState(State::Disconnected);
  238. }
  239. bool TcpTransport::trySendQueue() {
  240. // mSendMutex must be locked
  241. while (auto next = mSendQueue.peek()) {
  242. message_ptr message = std::move(*next);
  243. size_t size = message->size();
  244. if (!trySendMessage(message)) { // replaces message
  245. mSendQueue.exchange(message);
  246. updateBufferedAmount(-ptrdiff_t(size) + ptrdiff_t(message->size()));
  247. return false;
  248. }
  249. mSendQueue.pop();
  250. updateBufferedAmount(-ptrdiff_t(size));
  251. }
  252. return true;
  253. }
  254. bool TcpTransport::trySendMessage(message_ptr &message) {
  255. // mSendMutex must be locked
  256. auto data = reinterpret_cast<const char *>(message->data());
  257. auto size = message->size();
  258. while (size) {
  259. #if defined(__APPLE__) || defined(_WIN32)
  260. int flags = 0;
  261. #else
  262. int flags = MSG_NOSIGNAL;
  263. #endif
  264. int len = ::send(mSock, data, int(size), flags);
  265. if (len < 0) {
  266. if (sockerrno == SEAGAIN || sockerrno == SEWOULDBLOCK) {
  267. if (size < message->size())
  268. message = make_message(message->end() - size, message->end());
  269. return false;
  270. } else {
  271. PLOG_ERROR << "Connection closed, errno=" << sockerrno;
  272. throw std::runtime_error("Connection closed");
  273. }
  274. }
  275. data += len;
  276. size -= len;
  277. }
  278. message = nullptr;
  279. return true;
  280. }
  281. void TcpTransport::updateBufferedAmount(ptrdiff_t delta) {
  282. // Requires mSendMutex to be locked
  283. if (delta == 0)
  284. return;
  285. mBufferedAmount = size_t(std::max(ptrdiff_t(mBufferedAmount) + delta, ptrdiff_t(0)));
  286. // Synchronously call the buffered amount callback
  287. triggerBufferedAmount(mBufferedAmount);
  288. }
  289. void TcpTransport::triggerBufferedAmount(size_t amount) {
  290. try {
  291. mBufferedAmountCallback(amount);
  292. } catch (const std::exception &e) {
  293. PLOG_WARNING << "TCP buffered amount callback: " << e.what();
  294. }
  295. }
  296. void TcpTransport::process(PollService::Event event) {
  297. auto self = weak_from_this().lock();
  298. if (!self)
  299. return;
  300. try {
  301. switch (event) {
  302. case PollService::Event::Error: {
  303. PLOG_WARNING << "TCP connection terminated";
  304. break;
  305. }
  306. case PollService::Event::Timeout: {
  307. PLOG_VERBOSE << "TCP is idle";
  308. incoming(make_message(0));
  309. setPoll(PollService::Direction::In);
  310. return;
  311. }
  312. case PollService::Event::Out: {
  313. std::lock_guard lock(mSendMutex);
  314. if (trySendQueue())
  315. setPoll(PollService::Direction::In);
  316. return;
  317. }
  318. case PollService::Event::In: {
  319. const size_t bufferSize = 4096;
  320. char buffer[bufferSize];
  321. int len;
  322. while ((len = ::recv(mSock, buffer, bufferSize, 0)) > 0) {
  323. auto *b = reinterpret_cast<byte *>(buffer);
  324. incoming(make_message(b, b + len));
  325. }
  326. if (len == 0)
  327. break; // clean close
  328. if (sockerrno != SEAGAIN && sockerrno != SEWOULDBLOCK) {
  329. PLOG_WARNING << "TCP connection lost";
  330. break;
  331. }
  332. return;
  333. }
  334. default:
  335. // Ignore
  336. return;
  337. }
  338. } catch (const std::exception &e) {
  339. PLOG_ERROR << e.what();
  340. }
  341. PLOG_INFO << "TCP disconnected";
  342. PollService::Instance().remove(mSock);
  343. changeState(State::Disconnected);
  344. recv(nullptr);
  345. }
  346. void TcpTransport::processConnect(PollService::Event event) {
  347. try {
  348. if (event == PollService::Event::Error)
  349. throw std::runtime_error("TCP connection failed");
  350. if (event == PollService::Event::Timeout)
  351. throw std::runtime_error("TCP connection timed out");
  352. if (event != PollService::Event::Out)
  353. return;
  354. int err = 0;
  355. socklen_t errlen = sizeof(err);
  356. if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&err),
  357. &errlen) != 0)
  358. throw std::runtime_error("Failed to get socket error code");
  359. if (err != 0) {
  360. std::ostringstream msg;
  361. msg << "TCP connection failed, errno=" << err;
  362. throw std::runtime_error(msg.str());
  363. }
  364. // Success
  365. PLOG_INFO << "TCP connected";
  366. changeState(State::Connected);
  367. setPoll(PollService::Direction::In);
  368. } catch (const std::exception &e) {
  369. PLOG_DEBUG << e.what();
  370. PollService::Instance().remove(mSock);
  371. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  372. }
  373. }
  374. } // namespace rtc::impl
  375. #endif