Ver código fonte

Refactored TCP transport connection process

Paul-Louis Ageneau 3 anos atrás
pai
commit
ad2473a902
2 arquivos alterados com 129 adições e 67 exclusões
  1. 120 63
      src/impl/tcptransport.cpp
  2. 9 4
      src/impl/tcptransport.hpp

+ 120 - 63
src/impl/tcptransport.cpp

@@ -18,6 +18,7 @@
 
 #include "tcptransport.hpp"
 #include "internals.hpp"
+#include "threadpool.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
 
@@ -67,7 +68,10 @@ TcpTransport::TcpTransport(socket_t sock, state_callback callback)
 	mService = serv;
 }
 
-TcpTransport::~TcpTransport() { stop(); }
+TcpTransport::~TcpTransport() {
+	stop();
+	close();
+}
 
 void TcpTransport::start() {
 	Transport::start();
@@ -90,6 +94,7 @@ bool TcpTransport::stop() {
 
 bool TcpTransport::send(message_ptr message) {
 	std::lock_guard lock(mSendMutex);
+
 	if (state() != State::Connected)
 		throw std::runtime_error("Connection is not open");
 
@@ -119,87 +124,139 @@ bool TcpTransport::outgoing(message_ptr message) {
 	return false;
 }
 
+bool TcpTransport::isActive() const { return mIsActive; }
+
 string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
 
 void TcpTransport::connect() {
+	if (state() == State::Connecting)
+		throw std::logic_error("TCP connection is already in progress");
+
+	if (state() == State::Connected)
+		throw std::logic_error("TCP is already connected");
+
 	PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService;
 	changeState(State::Connecting);
 
-	// Resolve hostname
-	struct addrinfo hints = {};
-	hints.ai_family = AF_UNSPEC;
-	hints.ai_socktype = SOCK_STREAM;
-	hints.ai_protocol = IPPROTO_TCP;
-	hints.ai_flags = AI_ADDRCONFIG;
-
-	struct addrinfo *result = nullptr;
-	if (getaddrinfo(mHostname.c_str(), mService.c_str(), &hints, &result))
-		throw std::runtime_error("Resolution failed for \"" + mHostname + ":" + mService + "\"");
-
-	// Chain connection attempt to each address
-	auto attempt = [this, result](struct addrinfo *ai, auto recurse) {
-		if (!ai) {
-			PLOG_WARNING << "Connection to " << mHostname << ":" << mService << " failed";
+	ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::resolve, this));
+}
+
+void TcpTransport::resolve() {
+	std::lock_guard lock(mSendMutex);
+	mResolved.clear();
+
+	if (state() != State::Connecting)
+		return; // Cancelled
+
+	try {
+		PLOG_DEBUG << "Resolving " << mHostname << ":" << mService;
+
+		struct addrinfo hints = {};
+		hints.ai_family = AF_UNSPEC;
+		hints.ai_socktype = SOCK_STREAM;
+		hints.ai_protocol = IPPROTO_TCP;
+		hints.ai_flags = AI_ADDRCONFIG;
+
+		struct addrinfo *result = nullptr;
+		if (getaddrinfo(mHostname.c_str(), mService.c_str(), &hints, &result))
+			throw std::runtime_error("Resolution failed for \"" + mHostname + ":" + mService +
+			                         "\"");
+
+		try {
+			struct addrinfo *ai = result;
+			while (ai) {
+				struct sockaddr_storage addr;
+				std::memcpy(&addr, ai->ai_addr, ai->ai_addrlen);
+				mResolved.emplace_back(addr, socklen_t(ai->ai_addrlen));
+				ai = ai->ai_next;
+			}
+
+		} catch (...) {
 			freeaddrinfo(result);
-			changeState(State::Failed);
-			return;
+			throw;
 		}
 
+		freeaddrinfo(result);
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		changeState(State::Failed);
+		return;
+	}
+
+	ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
+}
+
+void TcpTransport::attempt() {
+	std::lock_guard lock(mSendMutex);
+
+	if (state() != State::Connecting)
+		return; // Cancelled
+
+	if (mSock == INVALID_SOCKET) {
+		::closesocket(mSock);
+		mSock = INVALID_SOCKET;
+	}
+
+	if (mResolved.empty()) {
+		PLOG_WARNING << "Connection to " << mHostname << ":" << mService << " failed";
+		changeState(State::Failed);
+		return;
+	}
+
+	try {
+		auto [addr, addrlen] = mResolved.front();
+		mResolved.pop_front();
+
+		createSocket(reinterpret_cast<const struct sockaddr *>(&addr), addrlen);
+
+	} catch (const std::runtime_error &e) {
+		PLOG_DEBUG << e.what();
+		ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
+		return;
+	}
+
+	// Poll out event callback
+	auto callback = [this](PollService::Event event) {
 		try {
-			prepare(ai->ai_addr, socklen_t(ai->ai_addrlen));
+			if (event == PollService::Event::Error)
+				throw std::runtime_error("TCP connection failed");
 
-		} catch (const std::runtime_error &e) {
-			PLOG_DEBUG << e.what();
-			recurse(ai->ai_next, recurse);
-		}
+			if (event == PollService::Event::Timeout)
+				throw std::runtime_error("TCP connection timed out");
 
-		// Poll out event callback
-		auto callback = [this, result, ai, recurse](PollService::Event event) mutable {
-			try {
-				if (event == PollService::Event::Error)
-					throw std::runtime_error("TCP connection failed");
-
-				if (event == PollService::Event::Timeout)
-					throw std::runtime_error("TCP connection timed out");
-
-				if (event != PollService::Event::Out)
-					return;
-
-				int err = 0;
-				socklen_t errlen = sizeof(err);
-				if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, (char *)&err, &errlen) != 0)
-					throw std::runtime_error("Failed to get socket error code");
-
-				if (err != 0) {
-					std::ostringstream msg;
-					msg << "TCP connection failed, errno=" << err;
-					throw std::runtime_error(msg.str());
-				}
-
-				// Success
-				PLOG_INFO << "TCP connected";
-				freeaddrinfo(result);
-				changeState(State::Connected);
-				setPoll(PollService::Direction::In);
+			if (event != PollService::Event::Out)
+				return;
 
-			} catch (const std::runtime_error &e) {
-				PLOG_DEBUG << e.what();
-				PollService::Instance().remove(mSock);
-				::closesocket(mSock);
-				mSock = INVALID_SOCKET;
-				recurse(ai->ai_next, recurse);
+			int err = 0;
+			socklen_t errlen = sizeof(err);
+			if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&err),
+			                 &errlen) != 0)
+				throw std::runtime_error("Failed to get socket error code");
+
+			if (err != 0) {
+				std::ostringstream msg;
+				msg << "TCP connection failed, errno=" << err;
+				throw std::runtime_error(msg.str());
 			}
-		};
 
-		const auto timeout = 10s;
-		PollService::Instance().add(mSock,
-		                            {PollService::Direction::Out, timeout, std::move(callback)});
+			// Success
+			PLOG_INFO << "TCP connected";
+			changeState(State::Connected);
+			setPoll(PollService::Direction::In);
+
+		} catch (const std::exception &e) {
+			PLOG_DEBUG << e.what();
+			PollService::Instance().remove(mSock);
+			ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
+		}
 	};
 
-	attempt(result, attempt);
+	const auto timeout = 10s;
+	PollService::Instance().add(mSock, {PollService::Direction::Out, timeout, std::move(callback)});
 }
 
-void TcpTransport::prepare(const sockaddr *addr, socklen_t addrlen) {
+void TcpTransport::createSocket(const struct sockaddr *addr, socklen_t addrlen) {
 	try {
 		char node[MAX_NUMERICNODE_LEN];
 		char serv[MAX_NUMERICSERV_LEN];

+ 9 - 4
src/impl/tcptransport.hpp

@@ -27,11 +27,13 @@
 
 #if RTC_ENABLE_WEBSOCKET
 
+#include <list>
 #include <mutex>
+#include <tuple>
 
 namespace rtc::impl {
 
-class TcpTransport : public Transport {
+class TcpTransport final : public Transport, public std::enable_shared_from_this<TcpTransport> {
 public:
 	TcpTransport(string hostname, string service, state_callback callback); // active
 	TcpTransport(socket_t sock, state_callback callback);                   // passive
@@ -44,13 +46,14 @@ public:
 	void incoming(message_ptr message) override;
 	bool outgoing(message_ptr message) override;
 
-	bool isActive() const { return mIsActive; }
-
+	bool isActive() const;
 	string remoteAddress() const;
 
 private:
 	void connect();
-	void prepare(const sockaddr *addr, socklen_t addrlen);
+	void resolve();
+	void attempt();
+	void createSocket(const struct sockaddr *addr, socklen_t addrlen);
 	void configureSocket();
 	void setPoll(PollService::Direction direction);
 	void close();
@@ -63,6 +66,8 @@ private:
 	const bool mIsActive;
 	string mHostname, mService;
 
+	std::list<std::tuple<struct sockaddr_storage, socklen_t>> mResolved;
+
 	socket_t mSock;
 	Queue<message_ptr> mSendQueue;
 	std::mutex mSendMutex;