tcptransport.cpp 12 KB


  1. /**
  2. * Copyright (c) 2020 Paul-Louis Ageneau
  3. *
  4. * This Source Code Form is subject to the terms of the Mozilla Public
  5. * License, v. 2.0. If a copy of the MPL was not distributed with this
  6. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  7. */
  #include "tcptransport.hpp"
  #include "internals.hpp"
  #include "threadpool.hpp"

  #if RTC_ENABLE_WEBSOCKET

  #ifndef _WIN32
  #include <fcntl.h>
  #include <unistd.h>
  #endif

  #include <algorithm>
  #include <chrono>
  #include <cstring>
  #include <limits>
  #include <memory>
  #include <sstream>
  17. namespace rtc::impl {
  18. using namespace std::placeholders;
  19. using namespace std::chrono_literals;
  20. using std::chrono::duration_cast;
  21. using std::chrono::milliseconds;
  namespace {

  // Rewrites an IPv4-mapped IPv6 socket address (::ffff:a.b.c.d) in place as a plain AF_INET
  // address and stores the new length in *len.
  // Returns true if the address was rewritten, false if it was left untouched (family is not
  // AF_INET6, or the IPv6 address is not v4-mapped).
  // The storage behind `sa` must be large enough to hold a sockaddr_in6 (e.g. sockaddr_storage).
  bool unmap_inet6_v4mapped(struct sockaddr *sa, socklen_t *len) {
      if (sa->sa_family != AF_INET6)
          return false;

      const auto *original = reinterpret_cast<const struct sockaddr_in6 *>(sa);
      if (!IN6_IS_ADDR_V4MAPPED(&original->sin6_addr))
          return false;

      // Snapshot the IPv6 address first: the IPv4 result is written over the same storage.
      const struct sockaddr_in6 saved = *original;

      auto *v4 = reinterpret_cast<struct sockaddr_in *>(sa);
      std::memset(v4, 0, sizeof(struct sockaddr_in));
      v4->sin_family = AF_INET;
      v4->sin_port = saved.sin6_port;

      // The embedded IPv4 address is the last 4 bytes of the 16-byte IPv6 address.
      const auto *bytes = reinterpret_cast<const uint8_t *>(&saved.sin6_addr);
      std::memcpy(&v4->sin_addr, bytes + 12, 4);

      *len = sizeof(struct sockaddr_in);
      return true;
  }

  } // namespace
  40. TcpTransport::TcpTransport(string hostname, string service, state_callback callback)
  41. : Transport(nullptr, std::move(callback)), mIsActive(true), mHostname(std::move(hostname)),
  42. mService(std::move(service)), mSock(INVALID_SOCKET) {
  43. PLOG_DEBUG << "Initializing TCP transport";
  44. }
  45. TcpTransport::TcpTransport(socket_t sock, state_callback callback)
  46. : Transport(nullptr, std::move(callback)), mIsActive(false), mSock(sock) {
  47. PLOG_DEBUG << "Initializing TCP transport with socket";
  48. // Configure socket
  49. configureSocket();
  50. // Retrieve hostname and service
  51. struct sockaddr_storage addr;
  52. socklen_t addrlen = sizeof(addr);
  53. if (::getpeername(mSock, reinterpret_cast<struct sockaddr *>(&addr), &addrlen) < 0)
  54. throw std::runtime_error("getpeername failed");
  55. unmap_inet6_v4mapped(reinterpret_cast<struct sockaddr *>(&addr), &addrlen);
  56. char node[MAX_NUMERICNODE_LEN];
  57. char serv[MAX_NUMERICSERV_LEN];
  58. if (::getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), addrlen, node,
  59. MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
  60. NI_NUMERICHOST | NI_NUMERICSERV) != 0)
  61. throw std::runtime_error("getnameinfo failed");
  62. mHostname = node;
  63. mService = serv;
  64. }
  65. TcpTransport::~TcpTransport() { close(); }
  66. void TcpTransport::onBufferedAmount(amount_callback callback) {
  67. mBufferedAmountCallback = std::move(callback);
  68. }
  69. void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) {
  70. mReadTimeout = readTimeout;
  71. }
  72. void TcpTransport::start() {
  73. if (mSock == INVALID_SOCKET) {
  74. connect();
  75. } else {
  76. changeState(State::Connected);
  77. setPoll(PollService::Direction::In);
  78. }
  79. }
  80. bool TcpTransport::send(message_ptr message) {
  81. std::lock_guard lock(mSendMutex);
  82. if (state() != State::Connected)
  83. throw std::runtime_error("Connection is not open");
  84. if (!message || message->size() == 0)
  85. return trySendQueue();
  86. PLOG_VERBOSE << "Send size=" << message->size();
  87. return outgoing(message);
  88. }
  89. void TcpTransport::incoming(message_ptr message) {
  90. if (!message)
  91. return;
  92. PLOG_VERBOSE << "Incoming size=" << message->size();
  93. recv(message);
  94. }
  95. bool TcpTransport::outgoing(message_ptr message) {
  96. // mSendMutex must be locked
  97. // Flush the queue, and if nothing is pending, try to send directly
  98. if (trySendQueue() && trySendMessage(message))
  99. return true;
  100. mSendQueue.push(message);
  101. updateBufferedAmount(ptrdiff_t(message->size()));
  102. setPoll(PollService::Direction::Both);
  103. return false;
  104. }
  105. bool TcpTransport::isActive() const { return mIsActive; }
  106. string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
  107. void TcpTransport::connect() {
  108. if (state() == State::Connecting)
  109. throw std::logic_error("TCP connection is already in progress");
  110. if (state() == State::Connected)
  111. throw std::logic_error("TCP is already connected");
  112. PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService;
  113. changeState(State::Connecting);
  114. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::resolve, this));
  115. }
  116. void TcpTransport::resolve() {
  117. std::lock_guard lock(mSendMutex);
  118. mResolved.clear();
  119. if (state() != State::Connecting)
  120. return; // Cancelled
  121. try {
  122. PLOG_DEBUG << "Resolving " << mHostname << ":" << mService;
  123. struct addrinfo hints = {};
  124. hints.ai_family = AF_UNSPEC;
  125. hints.ai_socktype = SOCK_STREAM;
  126. hints.ai_protocol = IPPROTO_TCP;
  127. hints.ai_flags = AI_ADDRCONFIG;
  128. struct addrinfo *result = nullptr;
  129. if (getaddrinfo(mHostname.c_str(), mService.c_str(), &hints, &result))
  130. throw std::runtime_error("Resolution failed for \"" + mHostname + ":" + mService +
  131. "\"");
  132. try {
  133. struct addrinfo *ai = result;
  134. while (ai) {
  135. struct sockaddr_storage addr;
  136. std::memcpy(&addr, ai->ai_addr, ai->ai_addrlen);
  137. mResolved.emplace_back(addr, socklen_t(ai->ai_addrlen));
  138. ai = ai->ai_next;
  139. }
  140. } catch (...) {
  141. freeaddrinfo(result);
  142. throw;
  143. }
  144. freeaddrinfo(result);
  145. } catch (const std::exception &e) {
  146. PLOG_WARNING << e.what();
  147. changeState(State::Failed);
  148. return;
  149. }
  150. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  151. }
  152. void TcpTransport::attempt() {
  153. try {
  154. std::lock_guard lock(mSendMutex);
  155. if (state() != State::Connecting)
  156. return; // Cancelled
  157. if (mSock == INVALID_SOCKET) {
  158. ::closesocket(mSock);
  159. mSock = INVALID_SOCKET;
  160. }
  161. if (mResolved.empty()) {
  162. PLOG_WARNING << "Connection to " << mHostname << ":" << mService << " failed";
  163. changeState(State::Failed);
  164. return;
  165. }
  166. auto [addr, addrlen] = mResolved.front();
  167. mResolved.pop_front();
  168. createSocket(reinterpret_cast<const struct sockaddr *>(&addr), addrlen);
  169. } catch (const std::runtime_error &e) {
  170. PLOG_DEBUG << e.what();
  171. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  172. return;
  173. }
  174. // Poll out event callback
  175. auto callback = [this](PollService::Event event) {
  176. try {
  177. if (event == PollService::Event::Error)
  178. throw std::runtime_error("TCP connection failed");
  179. if (event == PollService::Event::Timeout)
  180. throw std::runtime_error("TCP connection timed out");
  181. if (event != PollService::Event::Out)
  182. return;
  183. int err = 0;
  184. socklen_t errlen = sizeof(err);
  185. if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&err),
  186. &errlen) != 0)
  187. throw std::runtime_error("Failed to get socket error code");
  188. if (err != 0) {
  189. std::ostringstream msg;
  190. msg << "TCP connection failed, errno=" << err;
  191. throw std::runtime_error(msg.str());
  192. }
  193. // Success
  194. PLOG_INFO << "TCP connected";
  195. changeState(State::Connected);
  196. setPoll(PollService::Direction::In);
  197. } catch (const std::exception &e) {
  198. PLOG_DEBUG << e.what();
  199. PollService::Instance().remove(mSock);
  200. ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
  201. }
  202. };
  203. const auto timeout = 10s;
  204. PollService::Instance().add(mSock, {PollService::Direction::Out, timeout, std::move(callback)});
  205. }
  206. void TcpTransport::createSocket(const struct sockaddr *addr, socklen_t addrlen) {
  207. try {
  208. char node[MAX_NUMERICNODE_LEN];
  209. char serv[MAX_NUMERICSERV_LEN];
  210. if (getnameinfo(addr, addrlen, node, MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
  211. NI_NUMERICHOST | NI_NUMERICSERV) == 0) {
  212. PLOG_DEBUG << "Trying address " << node << ":" << serv;
  213. }
  214. PLOG_VERBOSE << "Creating TCP socket";
  215. // Create socket
  216. mSock = ::socket(addr->sa_family, SOCK_STREAM, IPPROTO_TCP);
  217. if (mSock == INVALID_SOCKET)
  218. throw std::runtime_error("TCP socket creation failed");
  219. // Configure socket
  220. configureSocket();
  221. // Initiate connection
  222. int ret = ::connect(mSock, addr, addrlen);
  223. if (ret < 0 && sockerrno != SEINPROGRESS && sockerrno != SEWOULDBLOCK) {
  224. std::ostringstream msg;
  225. msg << "TCP connection to " << node << ":" << serv << " failed, errno=" << sockerrno;
  226. throw std::runtime_error(msg.str());
  227. }
  228. } catch (...) {
  229. if (mSock != INVALID_SOCKET) {
  230. ::closesocket(mSock);
  231. mSock = INVALID_SOCKET;
  232. }
  233. throw;
  234. }
  235. }
  236. void TcpTransport::configureSocket() {
  237. // Set non-blocking
  238. ctl_t nbio = 1;
  239. if (::ioctlsocket(mSock, FIONBIO, &nbio) < 0)
  240. throw std::runtime_error("Failed to set socket non-blocking mode");
  241. // Disable the Nagle algorithm
  242. int nodelay = 1;
  243. ::setsockopt(mSock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&nodelay),
  244. sizeof(nodelay));
  245. #ifdef __APPLE__
  246. // MacOS lacks MSG_NOSIGNAL and requires SO_NOSIGPIPE instead
  247. const sockopt_t enabled = 1;
  248. if (::setsockopt(mSock, SOL_SOCKET, SO_NOSIGPIPE, &enabled, sizeof(enabled)) < 0)
  249. throw std::runtime_error("Failed to disable SIGPIPE for socket");
  250. #endif
  251. }
  252. void TcpTransport::setPoll(PollService::Direction direction) {
  253. PollService::Instance().add(
  254. mSock, {direction, direction == PollService::Direction::In ? mReadTimeout : nullopt,
  255. std::bind(&TcpTransport::process, this, _1)});
  256. }
  257. void TcpTransport::close() {
  258. std::lock_guard lock(mSendMutex);
  259. if (mSock != INVALID_SOCKET) {
  260. PLOG_DEBUG << "Closing TCP socket";
  261. PollService::Instance().remove(mSock);
  262. ::closesocket(mSock);
  263. mSock = INVALID_SOCKET;
  264. }
  265. changeState(State::Disconnected);
  266. }
  267. bool TcpTransport::trySendQueue() {
  268. // mSendMutex must be locked
  269. while (auto next = mSendQueue.peek()) {
  270. message_ptr message = std::move(*next);
  271. size_t size = message->size();
  272. if (!trySendMessage(message)) { // replaces message
  273. mSendQueue.exchange(message);
  274. updateBufferedAmount(-ptrdiff_t(size) + ptrdiff_t(message->size()));
  275. return false;
  276. }
  277. mSendQueue.pop();
  278. updateBufferedAmount(-ptrdiff_t(size));
  279. }
  280. return true;
  281. }
  282. bool TcpTransport::trySendMessage(message_ptr &message) {
  283. // mSendMutex must be locked
  284. auto data = reinterpret_cast<const char *>(message->data());
  285. auto size = message->size();
  286. while (size) {
  287. #if defined(__APPLE__) || defined(_WIN32)
  288. int flags = 0;
  289. #else
  290. int flags = MSG_NOSIGNAL;
  291. #endif
  292. int len = ::send(mSock, data, int(size), flags);
  293. if (len < 0) {
  294. if (sockerrno == SEAGAIN || sockerrno == SEWOULDBLOCK) {
  295. if (size < message->size())
  296. message = make_message(message->end() - size, message->end());
  297. return false;
  298. } else {
  299. PLOG_ERROR << "Connection closed, errno=" << sockerrno;
  300. throw std::runtime_error("Connection closed");
  301. }
  302. }
  303. data += len;
  304. size -= len;
  305. }
  306. message = nullptr;
  307. return true;
  308. }
  309. void TcpTransport::updateBufferedAmount(ptrdiff_t delta) {
  310. // Requires mSendMutex to be locked
  311. if (delta == 0)
  312. return;
  313. mBufferedAmount = size_t(std::max(ptrdiff_t(mBufferedAmount) + delta, ptrdiff_t(0)));
  314. // Synchronously call the buffered amount callback
  315. triggerBufferedAmount(mBufferedAmount);
  316. }
  317. void TcpTransport::triggerBufferedAmount(size_t amount) {
  318. try {
  319. mBufferedAmountCallback(amount);
  320. } catch (const std::exception &e) {
  321. PLOG_WARNING << "TCP buffered amount callback: " << e.what();
  322. }
  323. }
  324. void TcpTransport::process(PollService::Event event) {
  325. auto self = weak_from_this().lock();
  326. if (!self)
  327. return;
  328. try {
  329. switch (event) {
  330. case PollService::Event::Error: {
  331. PLOG_WARNING << "TCP connection terminated";
  332. break;
  333. }
  334. case PollService::Event::Timeout: {
  335. PLOG_VERBOSE << "TCP is idle";
  336. incoming(make_message(0));
  337. setPoll(PollService::Direction::In);
  338. return;
  339. }
  340. case PollService::Event::Out: {
  341. std::lock_guard lock(mSendMutex);
  342. if (trySendQueue())
  343. setPoll(PollService::Direction::In);
  344. return;
  345. }
  346. case PollService::Event::In: {
  347. const size_t bufferSize = 4096;
  348. char buffer[bufferSize];
  349. int len;
  350. while ((len = ::recv(mSock, buffer, bufferSize, 0)) > 0) {
  351. auto *b = reinterpret_cast<byte *>(buffer);
  352. incoming(make_message(b, b + len));
  353. }
  354. if (len == 0)
  355. break; // clean close
  356. if (sockerrno != SEAGAIN && sockerrno != SEWOULDBLOCK) {
  357. PLOG_WARNING << "TCP connection lost";
  358. break;
  359. }
  360. return;
  361. }
  362. default:
  363. // Ignore
  364. return;
  365. }
  366. } catch (const std::exception &e) {
  367. PLOG_ERROR << e.what();
  368. }
  369. PLOG_INFO << "TCP disconnected";
  370. PollService::Instance().remove(mSock);
  371. changeState(State::Disconnected);
  372. recv(nullptr);
  373. }
  374. } // namespace rtc::impl
  375. #endif