tcptransport.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. /**
  2. * Copyright (c) 2020 Paul-Louis Ageneau
  3. *
  4. * This Source Code Form is subject to the terms of the Mozilla Public
  5. * License, v. 2.0. If a copy of the MPL was not distributed with this
  6. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  7. */
  8. #include "tcptransport.hpp"
  9. #include "internals.hpp"
  10. #include "threadpool.hpp"
  11. #if RTC_ENABLE_WEBSOCKET
  12. #ifndef _WIN32
  13. #include <fcntl.h>
  14. #include <unistd.h>
  15. #endif
  16. #include <chrono>
  17. namespace rtc::impl {
  18. using namespace std::placeholders;
  19. using namespace std::chrono_literals;
  20. using std::chrono::duration_cast;
  21. using std::chrono::milliseconds;
  22. TcpTransport::TcpTransport(string hostname, string service, state_callback callback)
  23. : Transport(nullptr, std::move(callback)), mIsActive(true), mHostname(std::move(hostname)),
  24. mService(std::move(service)), mSock(INVALID_SOCKET) {
  25. PLOG_DEBUG << "Initializing TCP transport";
  26. }
  27. TcpTransport::TcpTransport(socket_t sock, state_callback callback)
  28. : Transport(nullptr, std::move(callback)), mIsActive(false), mSock(sock) {
  29. PLOG_DEBUG << "Initializing TCP transport with socket";
  30. // Configure socket
  31. configureSocket();
  32. // Retrieve hostname and service
  33. struct sockaddr_storage addr;
  34. socklen_t addrlen = sizeof(addr);
  35. if (::getpeername(mSock, reinterpret_cast<struct sockaddr *>(&addr), &addrlen) < 0)
  36. throw std::runtime_error("getsockname failed");
  37. char node[MAX_NUMERICNODE_LEN];
  38. char serv[MAX_NUMERICSERV_LEN];
  39. if (::getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), addrlen, node,
  40. MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
  41. NI_NUMERICHOST | NI_NUMERICSERV) != 0)
  42. throw std::runtime_error("getnameinfo failed");
  43. mHostname = node;
  44. mService = serv;
  45. }
  46. TcpTransport::~TcpTransport() {
  47. close();
  48. }
  49. void TcpTransport::onBufferedAmount(amount_callback callback) {
  50. mBufferedAmountCallback = std::move(callback);
  51. }
  52. void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) {
  53. mReadTimeout = readTimeout;
  54. }
  55. void TcpTransport::start() {
  56. if (mSock == INVALID_SOCKET) {
  57. connect();
  58. } else {
  59. changeState(State::Connected);
  60. setPoll(PollService::Direction::In);
  61. }
  62. }
  63. bool TcpTransport::send(message_ptr message) {
  64. std::lock_guard lock(mSendMutex);
  65. if (state() != State::Connected)
  66. throw std::runtime_error("Connection is not open");
  67. if (!message || message->size() == 0)
  68. return trySendQueue();
  69. PLOG_VERBOSE << "Send size=" << message->size();
  70. return outgoing(message);
  71. }
  72. void TcpTransport::incoming(message_ptr message) {
  73. if (!message)
  74. return;
  75. PLOG_VERBOSE << "Incoming size=" << message->size();
  76. recv(message);
  77. }
  78. bool TcpTransport::outgoing(message_ptr message) {
  79. // mSendMutex must be locked
  80. // Flush the queue, and if nothing is pending, try to send directly
  81. if (trySendQueue() && trySendMessage(message))
  82. return true;
  83. mSendQueue.push(message);
  84. updateBufferedAmount(ptrdiff_t(message->size()));
  85. setPoll(PollService::Direction::Both);
  86. return false;
  87. }
  88. bool TcpTransport::isActive() const { return mIsActive; }
  89. string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
  90. void TcpTransport::connect() {
  91. if (state() == State::Connecting)
  92. throw std::logic_error("TCP connection is already in progress");
  93. if (state() == State::Connected)
  94. throw std::logic_error("TCP is already connected");
  95. PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService;
  96. changeState(State::Connecting);
  97. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::resolve, this));
  98. }
  // Resolves mHostname:mService into mResolved, then schedules attempt() on the
  // thread pool.
  // Does nothing beyond clearing mResolved if the state is no longer Connecting.
  // On any failure the error is logged and the state becomes Failed.
  void TcpTransport::resolve() {
    std::lock_guard lock(mSendMutex);
    mResolved.clear();

    if (state() != State::Connecting)
      return; // Cancelled

    try {
      PLOG_DEBUG << "Resolving " << mHostname << ":" << mService;

      struct addrinfo hints = {};
      hints.ai_flags = AI_ADDRCONFIG;
      hints.ai_family = AF_UNSPEC;
      hints.ai_socktype = SOCK_STREAM;
      hints.ai_protocol = IPPROTO_TCP;

      struct addrinfo *list = nullptr;
      const int status = getaddrinfo(mHostname.c_str(), mService.c_str(), &hints, &list);
      if (status != 0)
        throw std::runtime_error("Resolution failed for \"" + mHostname + ":" + mService +
                                 "\"");

      try {
        // Copy every returned address so the list can be released right after
        for (struct addrinfo *entry = list; entry != nullptr; entry = entry->ai_next) {
          struct sockaddr_storage storage;
          std::memcpy(&storage, entry->ai_addr, entry->ai_addrlen);
          mResolved.emplace_back(storage, socklen_t(entry->ai_addrlen));
        }
      } catch (...) {
        freeaddrinfo(list);
        throw;
      }

      freeaddrinfo(list);

    } catch (const std::exception &error) {
      PLOG_WARNING << error.what();
      changeState(State::Failed);
      return;
    }

    ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  }
  135. void TcpTransport::attempt() {
  136. std::lock_guard lock(mSendMutex);
  137. if (state() != State::Connecting)
  138. return; // Cancelled
  139. if (mSock == INVALID_SOCKET) {
  140. ::closesocket(mSock);
  141. mSock = INVALID_SOCKET;
  142. }
  143. if (mResolved.empty()) {
  144. PLOG_WARNING << "Connection to " << mHostname << ":" << mService << " failed";
  145. changeState(State::Failed);
  146. return;
  147. }
  148. try {
  149. auto [addr, addrlen] = mResolved.front();
  150. mResolved.pop_front();
  151. createSocket(reinterpret_cast<const struct sockaddr *>(&addr), addrlen);
  152. } catch (const std::runtime_error &e) {
  153. PLOG_DEBUG << e.what();
  154. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  155. return;
  156. }
  157. // Poll out event callback
  158. auto callback = [this](PollService::Event event) {
  159. try {
  160. if (event == PollService::Event::Error)
  161. throw std::runtime_error("TCP connection failed");
  162. if (event == PollService::Event::Timeout)
  163. throw std::runtime_error("TCP connection timed out");
  164. if (event != PollService::Event::Out)
  165. return;
  166. int err = 0;
  167. socklen_t errlen = sizeof(err);
  168. if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&err),
  169. &errlen) != 0)
  170. throw std::runtime_error("Failed to get socket error code");
  171. if (err != 0) {
  172. std::ostringstream msg;
  173. msg << "TCP connection failed, errno=" << err;
  174. throw std::runtime_error(msg.str());
  175. }
  176. // Success
  177. PLOG_INFO << "TCP connected";
  178. changeState(State::Connected);
  179. setPoll(PollService::Direction::In);
  180. } catch (const std::exception &e) {
  181. PLOG_DEBUG << e.what();
  182. PollService::Instance().remove(mSock);
  183. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  184. }
  185. };
  186. const auto timeout = 10s;
  187. PollService::Instance().add(mSock, {PollService::Direction::Out, timeout, std::move(callback)});
  188. }
  189. void TcpTransport::createSocket(const struct sockaddr *addr, socklen_t addrlen) {
  190. try {
  191. char node[MAX_NUMERICNODE_LEN];
  192. char serv[MAX_NUMERICSERV_LEN];
  193. if (getnameinfo(addr, addrlen, node, MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
  194. NI_NUMERICHOST | NI_NUMERICSERV) == 0) {
  195. PLOG_DEBUG << "Trying address " << node << ":" << serv;
  196. }
  197. PLOG_VERBOSE << "Creating TCP socket";
  198. // Create socket
  199. mSock = ::socket(addr->sa_family, SOCK_STREAM, IPPROTO_TCP);
  200. if (mSock == INVALID_SOCKET)
  201. throw std::runtime_error("TCP socket creation failed");
  202. // Configure socket
  203. configureSocket();
  204. // Initiate connection
  205. int ret = ::connect(mSock, addr, addrlen);
  206. if (ret < 0 && sockerrno != SEINPROGRESS && sockerrno != SEWOULDBLOCK) {
  207. std::ostringstream msg;
  208. msg << "TCP connection to " << node << ":" << serv << " failed, errno=" << sockerrno;
  209. throw std::runtime_error(msg.str());
  210. }
  211. } catch (...) {
  212. if (mSock != INVALID_SOCKET) {
  213. ::closesocket(mSock);
  214. mSock = INVALID_SOCKET;
  215. }
  216. throw;
  217. }
  218. }
  219. void TcpTransport::configureSocket() {
  220. // Set non-blocking
  221. ctl_t nbio = 1;
  222. if (::ioctlsocket(mSock, FIONBIO, &nbio) < 0)
  223. throw std::runtime_error("Failed to set socket non-blocking mode");
  224. // Disable the Nagle algorithm
  225. int nodelay = 1;
  226. ::setsockopt(mSock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&nodelay),
  227. sizeof(nodelay));
  228. #ifdef __APPLE__
  229. // MacOS lacks MSG_NOSIGNAL and requires SO_NOSIGPIPE instead
  230. const sockopt_t enabled = 1;
  231. if (::setsockopt(mSock, SOL_SOCKET, SO_NOSIGPIPE, &enabled, sizeof(enabled)) < 0)
  232. throw std::runtime_error("Failed to disable SIGPIPE for socket");
  233. #endif
  234. }
  235. void TcpTransport::setPoll(PollService::Direction direction) {
  236. PollService::Instance().add(
  237. mSock, {direction, direction == PollService::Direction::In ? mReadTimeout : nullopt,
  238. std::bind(&TcpTransport::process, this, _1)});
  239. }
  240. void TcpTransport::close() {
  241. std::lock_guard lock(mSendMutex);
  242. if (mSock != INVALID_SOCKET) {
  243. PLOG_DEBUG << "Closing TCP socket";
  244. PollService::Instance().remove(mSock);
  245. ::closesocket(mSock);
  246. mSock = INVALID_SOCKET;
  247. }
  248. changeState(State::Disconnected);
  249. }
  250. bool TcpTransport::trySendQueue() {
  251. // mSendMutex must be locked
  252. while (auto next = mSendQueue.peek()) {
  253. message_ptr message = std::move(*next);
  254. size_t size = message->size();
  255. if (!trySendMessage(message)) { // replaces message
  256. mSendQueue.exchange(message);
  257. updateBufferedAmount(-ptrdiff_t(size) + ptrdiff_t(message->size()));
  258. return false;
  259. }
  260. mSendQueue.pop();
  261. updateBufferedAmount(-ptrdiff_t(size));
  262. }
  263. return true;
  264. }
  265. bool TcpTransport::trySendMessage(message_ptr &message) {
  266. // mSendMutex must be locked
  267. auto data = reinterpret_cast<const char *>(message->data());
  268. auto size = message->size();
  269. while (size) {
  270. #if defined(__APPLE__) || defined(_WIN32)
  271. int flags = 0;
  272. #else
  273. int flags = MSG_NOSIGNAL;
  274. #endif
  275. int len = ::send(mSock, data, int(size), flags);
  276. if (len < 0) {
  277. if (sockerrno == SEAGAIN || sockerrno == SEWOULDBLOCK) {
  278. message = make_message(message->end() - size, message->end());
  279. return false;
  280. } else {
  281. PLOG_ERROR << "Connection closed, errno=" << sockerrno;
  282. throw std::runtime_error("Connection closed");
  283. }
  284. }
  285. data += len;
  286. size -= len;
  287. }
  288. message = nullptr;
  289. return true;
  290. }
  291. void TcpTransport::updateBufferedAmount(ptrdiff_t delta) {
  292. // Requires mSendMutex to be locked
  293. if (delta == 0)
  294. return;
  295. mBufferedAmount = size_t(std::max(ptrdiff_t(mBufferedAmount) + delta, ptrdiff_t(0)));
  296. // Synchronously call the buffered amount callback
  297. triggerBufferedAmount(mBufferedAmount);
  298. }
  299. void TcpTransport::triggerBufferedAmount(size_t amount) {
  300. try {
  301. mBufferedAmountCallback(amount);
  302. } catch (const std::exception &e) {
  303. PLOG_WARNING << "TCP buffered amount callback: " << e.what();
  304. }
  305. }
  306. void TcpTransport::process(PollService::Event event) {
  307. auto self = weak_from_this().lock();
  308. if (!self)
  309. return;
  310. try {
  311. switch (event) {
  312. case PollService::Event::Error: {
  313. PLOG_WARNING << "TCP connection terminated";
  314. break;
  315. }
  316. case PollService::Event::Timeout: {
  317. PLOG_VERBOSE << "TCP is idle";
  318. incoming(make_message(0));
  319. setPoll(PollService::Direction::In);
  320. return;
  321. }
  322. case PollService::Event::Out: {
  323. if (trySendQueue())
  324. setPoll(PollService::Direction::In);
  325. return;
  326. }
  327. case PollService::Event::In: {
  328. const size_t bufferSize = 4096;
  329. char buffer[bufferSize];
  330. int len;
  331. while ((len = ::recv(mSock, buffer, bufferSize, 0)) > 0) {
  332. auto *b = reinterpret_cast<byte *>(buffer);
  333. incoming(make_message(b, b + len));
  334. }
  335. if (len == 0)
  336. break; // clean close
  337. if (sockerrno != SEAGAIN && sockerrno != SEWOULDBLOCK) {
  338. PLOG_WARNING << "TCP connection lost";
  339. break;
  340. }
  341. return;
  342. }
  343. default:
  344. // Ignore
  345. return;
  346. }
  347. } catch (const std::exception &e) {
  348. PLOG_ERROR << e.what();
  349. }
  350. PLOG_INFO << "TCP disconnected";
  351. PollService::Instance().remove(mSock);
  352. changeState(State::Disconnected);
  353. recv(nullptr);
  354. }
  355. } // namespace rtc::impl
  356. #endif