Browse Source

Merge branch 'v0.17'

Paul-Louis Ageneau 3 years ago
parent
commit
16f87814bc
3 changed files with 131 additions and 68 deletions
  1. 1 1
      examples/media-sender/main.cpp
  2. 120 63
      src/impl/tcptransport.cpp
  3. 10 4
      src/impl/tcptransport.hpp

+ 1 - 1
examples/media-sender/main.cpp

@@ -60,7 +60,7 @@ int main() {
 		});
 		});
 
 
 		SOCKET sock = socket(AF_INET, SOCK_DGRAM, 0);
 		SOCKET sock = socket(AF_INET, SOCK_DGRAM, 0);
-		sockaddr_in addr = {};
+		struct sockaddr_in addr = {};
 		addr.sin_family = AF_INET;
 		addr.sin_family = AF_INET;
 		addr.sin_addr.s_addr = inet_addr("127.0.0.1");
 		addr.sin_addr.s_addr = inet_addr("127.0.0.1");
 		addr.sin_port = htons(6000);
 		addr.sin_port = htons(6000);

+ 120 - 63
src/impl/tcptransport.cpp

@@ -18,6 +18,7 @@
 
 
 #include "tcptransport.hpp"
 #include "tcptransport.hpp"
 #include "internals.hpp"
 #include "internals.hpp"
+#include "threadpool.hpp"
 
 
 #if RTC_ENABLE_WEBSOCKET
 #if RTC_ENABLE_WEBSOCKET
 
 
@@ -67,7 +68,10 @@ TcpTransport::TcpTransport(socket_t sock, state_callback callback)
 	mService = serv;
 	mService = serv;
 }
 }
 
 
-TcpTransport::~TcpTransport() { stop(); }
+TcpTransport::~TcpTransport() {
+	stop();
+	close();
+}
 
 
 void TcpTransport::onBufferedAmount(amount_callback callback) {
 void TcpTransport::onBufferedAmount(amount_callback callback) {
 	mBufferedAmountCallback = std::move(callback);
 	mBufferedAmountCallback = std::move(callback);
@@ -98,6 +102,7 @@ bool TcpTransport::stop() {
 
 
 bool TcpTransport::send(message_ptr message) {
 bool TcpTransport::send(message_ptr message) {
 	std::lock_guard lock(mSendMutex);
 	std::lock_guard lock(mSendMutex);
+
 	if (state() != State::Connected)
 	if (state() != State::Connected)
 		throw std::runtime_error("Connection is not open");
 		throw std::runtime_error("Connection is not open");
 
 
@@ -128,87 +133,139 @@ bool TcpTransport::outgoing(message_ptr message) {
 	return false;
 	return false;
 }
 }
 
 
+bool TcpTransport::isActive() const { return mIsActive; }
+
 string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
 string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
 
 
 void TcpTransport::connect() {
 void TcpTransport::connect() {
+	if (state() == State::Connecting)
+		throw std::logic_error("TCP connection is already in progress");
+
+	if (state() == State::Connected)
+		throw std::logic_error("TCP is already connected");
+
 	PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService;
 	PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService;
 	changeState(State::Connecting);
 	changeState(State::Connecting);
 
 
-	// Resolve hostname
-	struct addrinfo hints = {};
-	hints.ai_family = AF_UNSPEC;
-	hints.ai_socktype = SOCK_STREAM;
-	hints.ai_protocol = IPPROTO_TCP;
-	hints.ai_flags = AI_ADDRCONFIG;
-
-	struct addrinfo *result = nullptr;
-	if (getaddrinfo(mHostname.c_str(), mService.c_str(), &hints, &result))
-		throw std::runtime_error("Resolution failed for \"" + mHostname + ":" + mService + "\"");
-
-	// Chain connection attempt to each address
-	auto attempt = [this, result](struct addrinfo *ai, auto recurse) {
-		if (!ai) {
-			PLOG_WARNING << "Connection to " << mHostname << ":" << mService << " failed";
+	ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::resolve, this));
+}
+
+void TcpTransport::resolve() {
+	std::lock_guard lock(mSendMutex);
+	mResolved.clear();
+
+	if (state() != State::Connecting)
+		return; // Cancelled
+
+	try {
+		PLOG_DEBUG << "Resolving " << mHostname << ":" << mService;
+
+		struct addrinfo hints = {};
+		hints.ai_family = AF_UNSPEC;
+		hints.ai_socktype = SOCK_STREAM;
+		hints.ai_protocol = IPPROTO_TCP;
+		hints.ai_flags = AI_ADDRCONFIG;
+
+		struct addrinfo *result = nullptr;
+		if (getaddrinfo(mHostname.c_str(), mService.c_str(), &hints, &result))
+			throw std::runtime_error("Resolution failed for \"" + mHostname + ":" + mService +
+			                         "\"");
+
+		try {
+			struct addrinfo *ai = result;
+			while (ai) {
+				struct sockaddr_storage addr;
+				std::memcpy(&addr, ai->ai_addr, ai->ai_addrlen);
+				mResolved.emplace_back(addr, socklen_t(ai->ai_addrlen));
+				ai = ai->ai_next;
+			}
+
+		} catch (...) {
 			freeaddrinfo(result);
 			freeaddrinfo(result);
-			changeState(State::Failed);
-			return;
+			throw;
 		}
 		}
 
 
+		freeaddrinfo(result);
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		changeState(State::Failed);
+		return;
+	}
+
+	ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
+}
+
+void TcpTransport::attempt() {
+	std::lock_guard lock(mSendMutex);
+
+	if (state() != State::Connecting)
+		return; // Cancelled
+
+	if (mSock == INVALID_SOCKET) {
+		::closesocket(mSock);
+		mSock = INVALID_SOCKET;
+	}
+
+	if (mResolved.empty()) {
+		PLOG_WARNING << "Connection to " << mHostname << ":" << mService << " failed";
+		changeState(State::Failed);
+		return;
+	}
+
+	try {
+		auto [addr, addrlen] = mResolved.front();
+		mResolved.pop_front();
+
+		createSocket(reinterpret_cast<const struct sockaddr *>(&addr), addrlen);
+
+	} catch (const std::runtime_error &e) {
+		PLOG_DEBUG << e.what();
+		ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
+		return;
+	}
+
+	// Poll out event callback
+	auto callback = [this](PollService::Event event) {
 		try {
 		try {
-			prepare(ai->ai_addr, socklen_t(ai->ai_addrlen));
+			if (event == PollService::Event::Error)
+				throw std::runtime_error("TCP connection failed");
 
 
-		} catch (const std::runtime_error &e) {
-			PLOG_DEBUG << e.what();
-			recurse(ai->ai_next, recurse);
-		}
+			if (event == PollService::Event::Timeout)
+				throw std::runtime_error("TCP connection timed out");
 
 
-		// Poll out event callback
-		auto callback = [this, result, ai, recurse](PollService::Event event) mutable {
-			try {
-				if (event == PollService::Event::Error)
-					throw std::runtime_error("TCP connection failed");
-
-				if (event == PollService::Event::Timeout)
-					throw std::runtime_error("TCP connection timed out");
-
-				if (event != PollService::Event::Out)
-					return;
-
-				int err = 0;
-				socklen_t errlen = sizeof(err);
-				if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, (char *)&err, &errlen) != 0)
-					throw std::runtime_error("Failed to get socket error code");
-
-				if (err != 0) {
-					std::ostringstream msg;
-					msg << "TCP connection failed, errno=" << err;
-					throw std::runtime_error(msg.str());
-				}
-
-				// Success
-				PLOG_INFO << "TCP connected";
-				freeaddrinfo(result);
-				changeState(State::Connected);
-				setPoll(PollService::Direction::In);
+			if (event != PollService::Event::Out)
+				return;
+
+			int err = 0;
+			socklen_t errlen = sizeof(err);
+			if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&err),
+			                 &errlen) != 0)
+				throw std::runtime_error("Failed to get socket error code");
 
 
-			} catch (const std::runtime_error &e) {
-				PLOG_DEBUG << e.what();
-				PollService::Instance().remove(mSock);
-				::closesocket(mSock);
-				mSock = INVALID_SOCKET;
-				recurse(ai->ai_next, recurse);
+			if (err != 0) {
+				std::ostringstream msg;
+				msg << "TCP connection failed, errno=" << err;
+				throw std::runtime_error(msg.str());
 			}
 			}
-		};
 
 
-		const auto timeout = 10s;
-		PollService::Instance().add(mSock,
-		                            {PollService::Direction::Out, timeout, std::move(callback)});
+			// Success
+			PLOG_INFO << "TCP connected";
+			changeState(State::Connected);
+			setPoll(PollService::Direction::In);
+
+		} catch (const std::exception &e) {
+			PLOG_DEBUG << e.what();
+			PollService::Instance().remove(mSock);
+			ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
+		}
 	};
 	};
 
 
-	attempt(result, attempt);
+	const auto timeout = 10s;
+	PollService::Instance().add(mSock, {PollService::Direction::Out, timeout, std::move(callback)});
 }
 }
 
 
-void TcpTransport::prepare(const sockaddr *addr, socklen_t addrlen) {
+void TcpTransport::createSocket(const struct sockaddr *addr, socklen_t addrlen) {
 	try {
 	try {
 		char node[MAX_NUMERICNODE_LEN];
 		char node[MAX_NUMERICNODE_LEN];
 		char serv[MAX_NUMERICSERV_LEN];
 		char serv[MAX_NUMERICSERV_LEN];

+ 10 - 4
src/impl/tcptransport.hpp

@@ -27,12 +27,14 @@
 
 
 #if RTC_ENABLE_WEBSOCKET
 #if RTC_ENABLE_WEBSOCKET
 
 
-#include <mutex>
 #include <chrono>
 #include <chrono>
+#include <list>
+#include <mutex>
+#include <tuple>
 
 
 namespace rtc::impl {
 namespace rtc::impl {
 
 
-class TcpTransport : public Transport {
+class TcpTransport final : public Transport, public std::enable_shared_from_this<TcpTransport> {
 public:
 public:
 	using amount_callback = std::function<void(size_t amount)>;
 	using amount_callback = std::function<void(size_t amount)>;
 
 
@@ -50,12 +52,14 @@ public:
 	void incoming(message_ptr message) override;
 	void incoming(message_ptr message) override;
 	bool outgoing(message_ptr message) override;
 	bool outgoing(message_ptr message) override;
 
 
-	bool isActive() const { return mIsActive; }
+	bool isActive() const;
 	string remoteAddress() const;
 	string remoteAddress() const;
 
 
 private:
 private:
 	void connect();
 	void connect();
-	void prepare(const sockaddr *addr, socklen_t addrlen);
+	void resolve();
+	void attempt();
+	void createSocket(const struct sockaddr *addr, socklen_t addrlen);
 	void configureSocket();
 	void configureSocket();
 	void setPoll(PollService::Direction direction);
 	void setPoll(PollService::Direction direction);
 	void close();
 	void close();
@@ -72,6 +76,8 @@ private:
 	amount_callback mBufferedAmountCallback;
 	amount_callback mBufferedAmountCallback;
 	optional<std::chrono::milliseconds> mReadTimeout;
 	optional<std::chrono::milliseconds> mReadTimeout;
 
 
+	std::list<std::tuple<struct sockaddr_storage, socklen_t>> mResolved;
+
 	socket_t mSock;
 	socket_t mSock;
 	Queue<message_ptr> mSendQueue;
 	Queue<message_ptr> mSendQueue;
 	size_t mBufferedAmount = 0;
 	size_t mBufferedAmount = 0;