tcptransport.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. /**
  2. * Copyright (c) 2020 Paul-Louis Ageneau
  3. *
  4. * This Source Code Form is subject to the terms of the Mozilla Public
  5. * License, v. 2.0. If a copy of the MPL was not distributed with this
  6. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  7. */
  8. #include "tcptransport.hpp"
  9. #include "internals.hpp"
  10. #include "threadpool.hpp"
  11. #if RTC_ENABLE_WEBSOCKET
  12. #ifndef _WIN32
  13. #include <fcntl.h>
  14. #include <unistd.h>
  15. #endif
  16. #include <chrono>
  17. namespace rtc::impl {
  18. using namespace std::placeholders;
  19. using namespace std::chrono_literals;
  20. using std::chrono::duration_cast;
  21. using std::chrono::milliseconds;
  22. TcpTransport::TcpTransport(string hostname, string service, state_callback callback)
  23. : Transport(nullptr, std::move(callback)), mIsActive(true), mHostname(std::move(hostname)),
  24. mService(std::move(service)), mSock(INVALID_SOCKET) {
  25. PLOG_DEBUG << "Initializing TCP transport";
  26. }
  27. TcpTransport::TcpTransport(socket_t sock, state_callback callback)
  28. : Transport(nullptr, std::move(callback)), mIsActive(false), mSock(sock) {
  29. PLOG_DEBUG << "Initializing TCP transport with socket";
  30. // Configure socket
  31. configureSocket();
  32. // Retrieve hostname and service
  33. struct sockaddr_storage addr;
  34. socklen_t addrlen = sizeof(addr);
  35. if (::getpeername(mSock, reinterpret_cast<struct sockaddr *>(&addr), &addrlen) < 0)
  36. throw std::runtime_error("getsockname failed");
  37. char node[MAX_NUMERICNODE_LEN];
  38. char serv[MAX_NUMERICSERV_LEN];
  39. if (::getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), addrlen, node,
  40. MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
  41. NI_NUMERICHOST | NI_NUMERICSERV) != 0)
  42. throw std::runtime_error("getnameinfo failed");
  43. mHostname = node;
  44. mService = serv;
  45. }
  46. TcpTransport::~TcpTransport() { close(); }
  47. void TcpTransport::onBufferedAmount(amount_callback callback) {
  48. mBufferedAmountCallback = std::move(callback);
  49. }
  50. void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) {
  51. mReadTimeout = readTimeout;
  52. }
  53. void TcpTransport::start() {
  54. if (mSock == INVALID_SOCKET) {
  55. connect();
  56. } else {
  57. changeState(State::Connected);
  58. setPoll(PollService::Direction::In);
  59. }
  60. }
  61. bool TcpTransport::send(message_ptr message) {
  62. std::lock_guard lock(mSendMutex);
  63. if (state() != State::Connected)
  64. throw std::runtime_error("Connection is not open");
  65. if (!message || message->size() == 0)
  66. return trySendQueue();
  67. PLOG_VERBOSE << "Send size=" << message->size();
  68. return outgoing(message);
  69. }
  70. void TcpTransport::incoming(message_ptr message) {
  71. if (!message)
  72. return;
  73. PLOG_VERBOSE << "Incoming size=" << message->size();
  74. recv(message);
  75. }
  76. bool TcpTransport::outgoing(message_ptr message) {
  77. // mSendMutex must be locked
  78. // Flush the queue, and if nothing is pending, try to send directly
  79. if (trySendQueue() && trySendMessage(message))
  80. return true;
  81. mSendQueue.push(message);
  82. updateBufferedAmount(ptrdiff_t(message->size()));
  83. setPoll(PollService::Direction::Both);
  84. return false;
  85. }
  86. bool TcpTransport::isActive() const { return mIsActive; }
  87. string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
  88. void TcpTransport::connect() {
  89. if (state() == State::Connecting)
  90. throw std::logic_error("TCP connection is already in progress");
  91. if (state() == State::Connected)
  92. throw std::logic_error("TCP is already connected");
  93. PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService;
  94. changeState(State::Connecting);
  95. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::resolve, this));
  96. }
  // Resolve mHostname:mService into the mResolved address list, then schedule the first
  // connection attempt. Moves to Failed if resolution throws or getaddrinfo() fails.
  void TcpTransport::resolve() {
      std::lock_guard lock(mSendMutex);
      mResolved.clear();

      // The connection may have been cancelled (e.g. by close()) in the meantime
      if (state() != State::Connecting)
          return;

      try {
          PLOG_DEBUG << "Resolving " << mHostname << ":" << mService;

          struct addrinfo hints = {};
          hints.ai_family = AF_UNSPEC;
          hints.ai_socktype = SOCK_STREAM;
          hints.ai_protocol = IPPROTO_TCP;
          hints.ai_flags = AI_ADDRCONFIG;

          struct addrinfo *result = nullptr;
          if (getaddrinfo(mHostname.c_str(), mService.c_str(), &hints, &result))
              throw std::runtime_error("Resolution failed for \"" + mHostname + ":" + mService +
                                       "\"");

          // Copy every candidate address, making sure the list is freed on all paths
          try {
              for (struct addrinfo *entry = result; entry; entry = entry->ai_next) {
                  struct sockaddr_storage storage;
                  std::memcpy(&storage, entry->ai_addr, entry->ai_addrlen);
                  mResolved.emplace_back(storage, socklen_t(entry->ai_addrlen));
              }
          } catch (...) {
              freeaddrinfo(result);
              throw;
          }
          freeaddrinfo(result);

      } catch (const std::exception &e) {
          PLOG_WARNING << e.what();
          changeState(State::Failed);
          return;
      }

      ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  }
  133. void TcpTransport::attempt() {
  134. std::lock_guard lock(mSendMutex);
  135. if (state() != State::Connecting)
  136. return; // Cancelled
  137. if (mSock == INVALID_SOCKET) {
  138. ::closesocket(mSock);
  139. mSock = INVALID_SOCKET;
  140. }
  141. if (mResolved.empty()) {
  142. PLOG_WARNING << "Connection to " << mHostname << ":" << mService << " failed";
  143. changeState(State::Failed);
  144. return;
  145. }
  146. try {
  147. auto [addr, addrlen] = mResolved.front();
  148. mResolved.pop_front();
  149. createSocket(reinterpret_cast<const struct sockaddr *>(&addr), addrlen);
  150. } catch (const std::runtime_error &e) {
  151. PLOG_DEBUG << e.what();
  152. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  153. return;
  154. }
  155. // Poll out event callback
  156. auto callback = [this](PollService::Event event) {
  157. try {
  158. if (event == PollService::Event::Error)
  159. throw std::runtime_error("TCP connection failed");
  160. if (event == PollService::Event::Timeout)
  161. throw std::runtime_error("TCP connection timed out");
  162. if (event != PollService::Event::Out)
  163. return;
  164. int err = 0;
  165. socklen_t errlen = sizeof(err);
  166. if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&err),
  167. &errlen) != 0)
  168. throw std::runtime_error("Failed to get socket error code");
  169. if (err != 0) {
  170. std::ostringstream msg;
  171. msg << "TCP connection failed, errno=" << err;
  172. throw std::runtime_error(msg.str());
  173. }
  174. // Success
  175. PLOG_INFO << "TCP connected";
  176. changeState(State::Connected);
  177. setPoll(PollService::Direction::In);
  178. } catch (const std::exception &e) {
  179. PLOG_DEBUG << e.what();
  180. PollService::Instance().remove(mSock);
  181. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  182. }
  183. };
  184. const auto timeout = 10s;
  185. PollService::Instance().add(mSock, {PollService::Direction::Out, timeout, std::move(callback)});
  186. }
  187. void TcpTransport::createSocket(const struct sockaddr *addr, socklen_t addrlen) {
  188. try {
  189. char node[MAX_NUMERICNODE_LEN];
  190. char serv[MAX_NUMERICSERV_LEN];
  191. if (getnameinfo(addr, addrlen, node, MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
  192. NI_NUMERICHOST | NI_NUMERICSERV) == 0) {
  193. PLOG_DEBUG << "Trying address " << node << ":" << serv;
  194. }
  195. PLOG_VERBOSE << "Creating TCP socket";
  196. // Create socket
  197. mSock = ::socket(addr->sa_family, SOCK_STREAM, IPPROTO_TCP);
  198. if (mSock == INVALID_SOCKET)
  199. throw std::runtime_error("TCP socket creation failed");
  200. // Configure socket
  201. configureSocket();
  202. // Initiate connection
  203. int ret = ::connect(mSock, addr, addrlen);
  204. if (ret < 0 && sockerrno != SEINPROGRESS && sockerrno != SEWOULDBLOCK) {
  205. std::ostringstream msg;
  206. msg << "TCP connection to " << node << ":" << serv << " failed, errno=" << sockerrno;
  207. throw std::runtime_error(msg.str());
  208. }
  209. } catch (...) {
  210. if (mSock != INVALID_SOCKET) {
  211. ::closesocket(mSock);
  212. mSock = INVALID_SOCKET;
  213. }
  214. throw;
  215. }
  216. }
  217. void TcpTransport::configureSocket() {
  218. // Set non-blocking
  219. ctl_t nbio = 1;
  220. if (::ioctlsocket(mSock, FIONBIO, &nbio) < 0)
  221. throw std::runtime_error("Failed to set socket non-blocking mode");
  222. // Disable the Nagle algorithm
  223. int nodelay = 1;
  224. ::setsockopt(mSock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&nodelay),
  225. sizeof(nodelay));
  226. #ifdef __APPLE__
  227. // MacOS lacks MSG_NOSIGNAL and requires SO_NOSIGPIPE instead
  228. const sockopt_t enabled = 1;
  229. if (::setsockopt(mSock, SOL_SOCKET, SO_NOSIGPIPE, &enabled, sizeof(enabled)) < 0)
  230. throw std::runtime_error("Failed to disable SIGPIPE for socket");
  231. #endif
  232. }
  233. void TcpTransport::setPoll(PollService::Direction direction) {
  234. PollService::Instance().add(
  235. mSock, {direction, direction == PollService::Direction::In ? mReadTimeout : nullopt,
  236. std::bind(&TcpTransport::process, this, _1)});
  237. }
  238. void TcpTransport::close() {
  239. std::lock_guard lock(mSendMutex);
  240. if (mSock != INVALID_SOCKET) {
  241. PLOG_DEBUG << "Closing TCP socket";
  242. PollService::Instance().remove(mSock);
  243. ::closesocket(mSock);
  244. mSock = INVALID_SOCKET;
  245. }
  246. changeState(State::Disconnected);
  247. }
  248. bool TcpTransport::trySendQueue() {
  249. // mSendMutex must be locked
  250. while (auto next = mSendQueue.peek()) {
  251. message_ptr message = std::move(*next);
  252. size_t size = message->size();
  253. if (!trySendMessage(message)) { // replaces message
  254. mSendQueue.exchange(message);
  255. updateBufferedAmount(-ptrdiff_t(size) + ptrdiff_t(message->size()));
  256. return false;
  257. }
  258. mSendQueue.pop();
  259. updateBufferedAmount(-ptrdiff_t(size));
  260. }
  261. return true;
  262. }
  263. bool TcpTransport::trySendMessage(message_ptr &message) {
  264. // mSendMutex must be locked
  265. auto data = reinterpret_cast<const char *>(message->data());
  266. auto size = message->size();
  267. while (size) {
  268. #if defined(__APPLE__) || defined(_WIN32)
  269. int flags = 0;
  270. #else
  271. int flags = MSG_NOSIGNAL;
  272. #endif
  273. int len = ::send(mSock, data, int(size), flags);
  274. if (len < 0) {
  275. if (sockerrno == SEAGAIN || sockerrno == SEWOULDBLOCK) {
  276. message = make_message(message->end() - size, message->end());
  277. return false;
  278. } else {
  279. PLOG_ERROR << "Connection closed, errno=" << sockerrno;
  280. throw std::runtime_error("Connection closed");
  281. }
  282. }
  283. data += len;
  284. size -= len;
  285. }
  286. message = nullptr;
  287. return true;
  288. }
  289. void TcpTransport::updateBufferedAmount(ptrdiff_t delta) {
  290. // Requires mSendMutex to be locked
  291. if (delta == 0)
  292. return;
  293. mBufferedAmount = size_t(std::max(ptrdiff_t(mBufferedAmount) + delta, ptrdiff_t(0)));
  294. // Synchronously call the buffered amount callback
  295. triggerBufferedAmount(mBufferedAmount);
  296. }
  297. void TcpTransport::triggerBufferedAmount(size_t amount) {
  298. try {
  299. mBufferedAmountCallback(amount);
  300. } catch (const std::exception &e) {
  301. PLOG_WARNING << "TCP buffered amount callback: " << e.what();
  302. }
  303. }
  304. void TcpTransport::process(PollService::Event event) {
  305. auto self = weak_from_this().lock();
  306. if (!self)
  307. return;
  308. try {
  309. switch (event) {
  310. case PollService::Event::Error: {
  311. PLOG_WARNING << "TCP connection terminated";
  312. break;
  313. }
  314. case PollService::Event::Timeout: {
  315. PLOG_VERBOSE << "TCP is idle";
  316. incoming(make_message(0));
  317. setPoll(PollService::Direction::In);
  318. return;
  319. }
  320. case PollService::Event::Out: {
  321. if (trySendQueue())
  322. setPoll(PollService::Direction::In);
  323. return;
  324. }
  325. case PollService::Event::In: {
  326. const size_t bufferSize = 4096;
  327. char buffer[bufferSize];
  328. int len;
  329. while ((len = ::recv(mSock, buffer, bufferSize, 0)) > 0) {
  330. auto *b = reinterpret_cast<byte *>(buffer);
  331. incoming(make_message(b, b + len));
  332. }
  333. if (len == 0)
  334. break; // clean close
  335. if (sockerrno != SEAGAIN && sockerrno != SEWOULDBLOCK) {
  336. PLOG_WARNING << "TCP connection lost";
  337. break;
  338. }
  339. return;
  340. }
  341. default:
  342. // Ignore
  343. return;
  344. }
  345. } catch (const std::exception &e) {
  346. PLOG_ERROR << e.what();
  347. }
  348. PLOG_INFO << "TCP disconnected";
  349. PollService::Instance().remove(mSock);
  350. changeState(State::Disconnected);
  351. recv(nullptr);
  352. }
  353. } // namespace rtc::impl
  354. #endif