Browse Source

Merge pull request #807 from paullouisageneau/dtls-threadpool

Move DTLS and TLS transports to thread pool
Paul-Louis Ageneau 2 years ago
parent
commit
b1ec5747bd

+ 1 - 1
src/impl/datachannel.cpp

@@ -108,7 +108,7 @@ void DataChannel::remoteClose() {
 }
 
 optional<message_variant> DataChannel::receive() {
-	auto next = mRecvQueue.tryPop();
+	auto next = mRecvQueue.pop();
 	return next ? std::make_optional(to_variant(std::move(**next))) : nullopt;
 }
 

+ 185 - 150
src/impl/dtlstransport.cpp

@@ -9,6 +9,7 @@
 #include "dtlstransport.hpp"
 #include "icetransport.hpp"
 #include "internals.hpp"
+#include "threadpool.hpp"
 
 #include <algorithm>
 #include <chrono>
@@ -27,6 +28,16 @@ using namespace std::chrono;
 
 namespace rtc::impl {
 
+void DtlsTransport::enqueueRecv() {
+	if (mPendingRecvCount > 0)
+		return;
+
+	if (auto shared_this = weak_from_this().lock()) {
+		++mPendingRecvCount;
+		ThreadPool::Instance().enqueue(&DtlsTransport::doRecv, std::move(shared_this));
+	}
+}
+
 #if USE_GNUTLS
 
 void DtlsTransport::Init() {
@@ -50,7 +61,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 	gnutls_certificate_credentials_t creds = mCertificate->credentials();
 	gnutls_certificate_set_verify_function(creds, CertificateCallback);
 
-	unsigned int flags = GNUTLS_DATAGRAM | (mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER);
+	unsigned int flags =
+	    GNUTLS_DATAGRAM | GNUTLS_NONBLOCK | (mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER);
 	gnutls::check(gnutls_init(&mSession, flags));
 
 	try {
@@ -98,22 +110,22 @@ DtlsTransport::~DtlsTransport() {
 }
 
 void DtlsTransport::start() {
-	if(mStarted.exchange(true))
-		return;
-
-	PLOG_DEBUG << "Starting DTLS recv thread";
+	PLOG_DEBUG << "Starting DTLS transport";
 	registerIncoming();
-	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+	changeState(State::Connecting);
+
+	size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
+	gnutls_dtls_set_mtu(mSession, static_cast<unsigned int>(mtu));
+	PLOG_VERBOSE << "DTLS MTU set to " << mtu;
+
+	enqueueRecv(); // to initiate the handshake
 }
 
 void DtlsTransport::stop() {
-	if(!mStarted.exchange(false))
-		return;
-
-	PLOG_DEBUG << "Stopping DTLS recv thread";
+	PLOG_DEBUG << "Stopping DTLS transport";
 	unregisterIncoming();
 	mIncomingQueue.stop();
-	mRecvThread.join();
+	enqueueRecv();
 }
 
 bool DtlsTransport::send(message_ptr message) {
@@ -122,7 +134,6 @@ bool DtlsTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
-
 	ssize_t ret;
 	do {
 		std::lock_guard lock(mSendMutex);
@@ -147,6 +158,7 @@ void DtlsTransport::incoming(message_ptr message) {
 
 	PLOG_VERBOSE << "Incoming size=" << message->size();
 	mIncomingQueue.push(message);
+	enqueueRecv();
 }
 
 bool DtlsTransport::outgoing(message_ptr message) {
@@ -166,79 +178,85 @@ void DtlsTransport::postHandshake() {
 	// Dummy
 }
 
-void DtlsTransport::runRecvLoop() {
-	const size_t bufferSize = 4096;
+void DtlsTransport::doRecv() {
+	std::lock_guard lock(mRecvMutex);
+	--mPendingRecvCount;
 
-	// Handshake loop
-	try {
-		changeState(State::Connecting);
-
-		size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
-		gnutls_dtls_set_mtu(mSession, static_cast<unsigned int>(mtu));
-		PLOG_VERBOSE << "SSL MTU set to " << mtu;
+	if (state() != State::Connecting && state() != State::Connected)
+		return;
 
-		int ret;
-		do {
-			ret = gnutls_handshake(mSession);
+	try {
+		const size_t bufferSize = 4096;
+		char buffer[bufferSize];
 
-			if (ret == GNUTLS_E_LARGE_PACKET)
-				throw std::runtime_error("MTU is too low");
+		// Handle handshake if connecting
+		if (state() == State::Connecting) {
+			int ret;
+			do {
+				ret = gnutls_handshake(mSession);
+
+				if (ret == GNUTLS_E_AGAIN) {
+					// Schedule next call on timeout and return
+					auto timeout = milliseconds(gnutls_dtls_get_timeout(mSession));
+					ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() {
+						if (auto locked = weak_this.lock())
+							locked->doRecv();
+					});
+					return;
+				}
 
-		} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
-		         !gnutls::check(ret, "DTLS handshake failed"));
+				if (ret == GNUTLS_E_LARGE_PACKET) {
+					throw std::runtime_error("MTU is too low");
+				}
 
-		// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
-		// See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
-		gnutls_dtls_set_mtu(mSession, bufferSize + 1);
+			} while (!gnutls::check(ret, "DTLS handshake failed")); // Re-call on non-fatal error
 
-	} catch (const std::exception &e) {
-		PLOG_ERROR << "DTLS handshake: " << e.what();
-		changeState(State::Failed);
-		return;
-	}
+			// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
+			// See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
+			gnutls_dtls_set_mtu(mSession, bufferSize + 1);
 
-	// Receive loop
-	try {
-		PLOG_INFO << "DTLS handshake finished";
-		postHandshake();
-		changeState(State::Connected);
+			PLOG_INFO << "DTLS handshake finished";
+			changeState(State::Connected);
+			postHandshake();
+		}
 
-		char buffer[bufferSize];
+		if (state() == State::Connected) {
+			while (true) {
+				ssize_t ret = gnutls_record_recv(mSession, buffer, bufferSize);
 
-		while (true) {
-			ssize_t ret;
-			do {
-				ret = gnutls_record_recv(mSession, buffer, bufferSize);
-			} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
-
-			// RFC 8827: Implementations MUST NOT implement DTLS renegotiation and MUST reject it
-			// with a "no_renegotiation" alert if offered.
-			// See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5
-			if (ret == GNUTLS_E_REHANDSHAKE) {
-				do {
-					std::lock_guard lock(mSendMutex);
-					ret = gnutls_alert_send(mSession, GNUTLS_AL_WARNING, GNUTLS_A_NO_RENEGOTIATION);
-				} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
-				continue;
-			}
+				if (ret == GNUTLS_E_AGAIN) {
+					return;
+				}
 
-			// Consider premature termination as remote closing
-			if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
-				PLOG_DEBUG << "DTLS connection terminated";
-				break;
-			}
+				// RFC 8827: Implementations MUST NOT implement DTLS renegotiation and MUST reject
+				// it with a "no_renegotiation" alert if offered. See
+				// https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5
+				if (ret == GNUTLS_E_REHANDSHAKE) {
+					do {
+						std::lock_guard lock(mSendMutex);
+						ret = gnutls_alert_send(mSession, GNUTLS_AL_WARNING,
+						                        GNUTLS_A_NO_RENEGOTIATION);
+					} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
+					continue;
+				}
 
-			if (gnutls::check(ret)) {
-				if (ret == 0) {
-					// Closed
-					PLOG_DEBUG << "DTLS connection cleanly closed";
+				// Consider premature termination as remote closing
+				if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
+					PLOG_DEBUG << "DTLS connection terminated";
 					break;
 				}
-				auto *b = reinterpret_cast<byte *>(buffer);
-				recv(make_message(b, b + ret));
+
+				if (gnutls::check(ret)) {
+					if (ret == 0) {
+						// Closed
+						PLOG_DEBUG << "DTLS connection cleanly closed";
+						break;
+					}
+					auto *b = reinterpret_cast<byte *>(buffer);
+					recv(make_message(b, b + ret));
+				}
 			}
 		}
-
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "DTLS recv: " << e.what();
 	}
@@ -303,7 +321,13 @@ ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *dat
 ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
 	try {
-		while (auto next = t->mIncomingQueue.pop()) {
+		while (t->mIncomingQueue.running()) {
+			auto next = t->mIncomingQueue.pop();
+			if (!next) {
+				gnutls_transport_set_errno(t->mSession, EAGAIN);
+				return -1;
+			}
+
 			message_ptr message = std::move(*next);
 			if (t->demuxMessage(message))
 				continue;
@@ -325,12 +349,10 @@ ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size
 	}
 }
 
-int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
+int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* ms */) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
 	try {
-		bool isReadable = t->mIncomingQueue.wait(
-		    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
-		return isReadable ? 1 : 0;
+		return !t->mIncomingQueue.empty() ? 1 : 0;
 
 	} catch (const std::exception &e) {
 		PLOG_WARNING << e.what();
@@ -438,11 +460,13 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 		// See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5 Warning:
 		// SSL_set_tlsext_use_srtp() returns 0 on success and 1 on error
 		// Try to use GCM suite
-		if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AEAD_AES_256_GCM:SRTP_AEAD_AES_128_GCM:SRTP_AES128_CM_SHA1_80")) {
+		if (SSL_set_tlsext_use_srtp(
+		        mSsl, "SRTP_AEAD_AES_256_GCM:SRTP_AEAD_AES_128_GCM:SRTP_AES128_CM_SHA1_80")) {
 			if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"))
 				throw std::runtime_error("Failed to set SRTP profile: " +
-							openssl::error_string(ERR_get_error()));
+				                         openssl::error_string(ERR_get_error()));
 		}
+
 	} catch (...) {
 		if (mSsl)
 			SSL_free(mSsl);
@@ -465,23 +489,26 @@ DtlsTransport::~DtlsTransport() {
 }
 
 void DtlsTransport::start() {
-	if(mStarted.exchange(true))
-		return;
-
-	PLOG_DEBUG << "Starting DTLS recv thread";
+	PLOG_DEBUG << "Starting DTLS transport";
 	registerIncoming();
-	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+	changeState(State::Connecting);
+
+	size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
+	SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
+	PLOG_VERBOSE << "DTLS MTU set to " << mtu;
+
+	// Initiate the handshake
+	int ret = SSL_do_handshake(mSsl);
+	openssl::check(mSsl, ret, "Handshake initiation failed");
+
+	handleTimeout();
 }
 
 void DtlsTransport::stop() {
-	if(!mStarted.exchange(false))
-		return;
-
-	PLOG_DEBUG << "Stopping DTLS recv thread";
+	PLOG_DEBUG << "Stopping DTLS transport";
 	unregisterIncoming();
 	mIncomingQueue.stop();
-	mRecvThread.join();
-	SSL_shutdown(mSsl);
+	enqueueRecv();
 }
 
 bool DtlsTransport::send(message_ptr message) {
@@ -501,11 +528,13 @@ bool DtlsTransport::send(message_ptr message) {
 void DtlsTransport::incoming(message_ptr message) {
 	if (!message) {
 		mIncomingQueue.stop();
+		enqueueRecv();
 		return;
 	}
 
 	PLOG_VERBOSE << "Incoming size=" << message->size();
 	mIncomingQueue.push(message);
+	enqueueRecv();
 }
 
 bool DtlsTransport::outgoing(message_ptr message) {
@@ -525,86 +554,65 @@ void DtlsTransport::postHandshake() {
 	// Dummy
 }
 
-void DtlsTransport::runRecvLoop() {
-	const size_t bufferSize = 4096;
-	try {
-		changeState(State::Connecting);
-
-		size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
-		SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
-		PLOG_VERBOSE << "SSL MTU set to " << mtu;
+void DtlsTransport::doRecv() {
+	std::lock_guard lock(mRecvMutex);
+	--mPendingRecvCount;
 
-		// Initiate the handshake
-		int ret = SSL_do_handshake(mSsl);
-		openssl::check(mSsl, ret, "Handshake failed");
+	if (state() != State::Connecting && state() != State::Connected)
+		return;
 
+	try {
+		const size_t bufferSize = 4096;
 		byte buffer[bufferSize];
+
+		// Process pending messages
 		while (mIncomingQueue.running()) {
-			// Process pending messages
-			while (auto next = mIncomingQueue.tryPop()) {
-				message_ptr message = std::move(*next);
-				if (demuxMessage(message))
-					continue;
+			auto next = mIncomingQueue.pop();
+			if (!next) {
+				// No more messages pending, handle timeout if connecting
+				if (state() == State::Connecting)
+					handleTimeout();
 
-				BIO_write(mInBio, message->data(), int(message->size()));
+				return;
+			}
 
-				if (state() == State::Connecting) {
-					// Continue the handshake
-					ret = SSL_do_handshake(mSsl);
-					if (!openssl::check(mSsl, ret, "Handshake failed"))
-						break;
+			message_ptr message = std::move(*next);
+			if (demuxMessage(message))
+				continue;
 
-					if (SSL_is_init_finished(mSsl)) {
-						// RFC 8261: DTLS MUST support sending messages larger than the current path
-						// MTU See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
-						SSL_set_mtu(mSsl, bufferSize + 1);
+			BIO_write(mInBio, message->data(), int(message->size()));
 
-						PLOG_INFO << "DTLS handshake finished";
-						postHandshake();
-						changeState(State::Connected);
-					}
-				} else {
-					ret = SSL_read(mSsl, buffer, bufferSize);
-					if (!openssl::check(mSsl, ret))
-						break;
+			if (state() == State::Connecting) {
+				// Continue the handshake
+				int ret = SSL_do_handshake(mSsl);
+				if (!openssl::check(mSsl, ret, "Handshake failed"))
+					break;
 
-					if (ret > 0)
-						recv(make_message(buffer, buffer + ret));
-				}
-			}
+				if (SSL_is_init_finished(mSsl)) {
+					// RFC 8261: DTLS MUST support sending messages larger than the current path
+					// MTU See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
+					SSL_set_mtu(mSsl, bufferSize + 1);
 
-			// No more messages pending, retransmit and rearm timeout if connecting
-			optional<milliseconds> duration;
-			if (state() == State::Connecting) {
-				// Warning: This function breaks the usual return value convention
-				ret = DTLSv1_handle_timeout(mSsl);
-				if (ret < 0) {
-					throw std::runtime_error("Handshake timeout"); // write BIO can't fail
-				} else if (ret > 0) {
-					LOG_VERBOSE << "OpenSSL did DTLS retransmit";
+					PLOG_INFO << "DTLS handshake finished";
+					postHandshake();
+					changeState(State::Connected);
 				}
+			} else {
+				int ret = SSL_read(mSsl, buffer, bufferSize);
+				if (!openssl::check(mSsl, ret))
+					break;
 
-				struct timeval timeout = {};
-				if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
-					duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
-					// Also handle handshake timeout manually because OpenSSL actually doesn't...
-					// OpenSSL backs off exponentially in base 2 starting from the recommended 1s
-					// so this allows for 5 retransmissions and fails after roughly 30s.
-					if (duration > 30s) {
-						throw std::runtime_error("Handshake timeout");
-					} else {
-						LOG_VERBOSE << "OpenSSL DTLS retransmit timeout is " << duration->count()
-						            << "ms";
-					}
-				}
+				if (ret > 0)
+					recv(make_message(buffer, buffer + ret));
 			}
-
-			mIncomingQueue.wait(duration);
 		}
+
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "DTLS recv: " << e.what();
 	}
 
+	SSL_shutdown(mSsl);
+
 	if (state() == State::Connected) {
 		PLOG_INFO << "DTLS closed";
 		changeState(State::Disconnected);
@@ -615,6 +623,33 @@ void DtlsTransport::runRecvLoop() {
 	}
 }
 
+void DtlsTransport::handleTimeout() {
+	// Warning: This function breaks the usual return value convention
+	int ret = DTLSv1_handle_timeout(mSsl);
+	if (ret < 0) {
+		throw std::runtime_error("Handshake timeout"); // write BIO can't fail
+	} else if (ret > 0) {
+		LOG_VERBOSE << "DTLS retransmit done";
+	}
+
+	struct timeval tv = {};
+	if (DTLSv1_get_timeout(mSsl, &tv)) {
+		auto timeout = milliseconds(tv.tv_sec * 1000 + tv.tv_usec / 1000);
+		// Also handle handshake timeout manually because OpenSSL actually
+		// doesn't... OpenSSL backs off exponentially in base 2 starting from the
+		// recommended 1s so this allows for 5 retransmissions and fails after
+		// roughly 30s.
+		if (timeout > 30s)
+			throw std::runtime_error("Handshake timeout");
+
+		LOG_VERBOSE << "DTLS retransmit timeout is " << timeout.count() << "ms";
+		ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() {
+			if (auto locked = weak_this.lock())
+				locked->doRecv();
+		});
+	}
+}
+
 int DtlsTransport::CertificateCallback(int /*preverify_ok*/, X509_STORE_CTX *ctx) {
 	SSL *ssl =
 	    static_cast<SSL *>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));

+ 7 - 4
src/impl/dtlstransport.hpp

@@ -25,7 +25,7 @@ namespace rtc::impl {
 
 class IceTransport;
 
-class DtlsTransport : public Transport {
+class DtlsTransport : public Transport, public std::enable_shared_from_this<DtlsTransport> {
 public:
 	static void Init();
 	static void Cleanup();
@@ -48,7 +48,8 @@ protected:
 	virtual bool demuxMessage(message_ptr message);
 	virtual void postHandshake();
 
-	void runRecvLoop();
+	void enqueueRecv();
+	void doRecv();
 
 	const optional<size_t> mMtu;
 	const certificate_ptr mCertificate;
@@ -56,8 +57,8 @@ protected:
 	const bool mIsClient;
 
 	Queue<message_ptr> mIncomingQueue;
-	std::thread mRecvThread;
-	std::atomic<bool> mStarted = false;
+	std::atomic<int> mPendingRecvCount = 0;
+	std::mutex mRecvMutex;
 	std::atomic<unsigned int> mCurrentDscp = 0;
 	std::atomic<bool> mOutgoingResult = true;
 
@@ -74,6 +75,8 @@ protected:
 	SSL *mSsl = NULL;
 	BIO *mInBio, *mOutBio;
 
+	void handleTimeout();
+
 	static BIO_METHOD *BioMethods;
 	static int TransportExIndex;
 	static std::mutex GlobalMutex;

+ 8 - 2
src/impl/init.cpp

@@ -11,9 +11,9 @@
 
 #include "certificate.hpp"
 #include "dtlstransport.hpp"
+#include "icetransport.hpp"
 #include "pollservice.hpp"
 #include "sctptransport.hpp"
-#include "icetransport.hpp"
 #include "threadpool.hpp"
 #include "tls.hpp"
 
@@ -29,6 +29,8 @@
 #include <winsock2.h>
 #endif
 
+#include <thread>
+
 namespace rtc::impl {
 
 struct Init::TokenPayload {
@@ -115,7 +117,11 @@ void Init::doInit() {
 		throw std::runtime_error("WSAStartup failed, error=" + std::to_string(WSAGetLastError()));
 #endif
 
-	ThreadPool::Instance().spawn(THREADPOOL_SIZE);
+	int concurrency = std::thread::hardware_concurrency();
+	int count = std::max(concurrency, MIN_THREADPOOL_SIZE);
+	PLOG_DEBUG << "Spawning " << count << " threads";
+	ThreadPool::Instance().spawn(count);
+
 #if RTC_ENABLE_WEBSOCKET
 	PollService::Instance().start();
 #endif

+ 1 - 1
src/impl/internals.hpp

@@ -43,7 +43,7 @@ const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not
 
 const size_t RECV_QUEUE_LIMIT = 1024 * 1024; // Max per-channel queue size
 
-const int THREADPOOL_SIZE = 4; // Number of threads in the global thread pool (>= 2)
+const int MIN_THREADPOOL_SIZE = 4; // Minimum number of threads in the global thread pool (>= 2)
 
 const size_t DEFAULT_MTU = RTC_DEFAULT_MTU; // defined in rtc.h
 

+ 2 - 2
src/impl/peerconnection.cpp

@@ -1110,7 +1110,7 @@ void PeerConnection::triggerTrack(weak_ptr<Track> weakTrack) {
 
 void PeerConnection::triggerPendingDataChannels() {
 	while (dataChannelCallback) {
-		auto next = mPendingDataChannels.tryPop();
+		auto next = mPendingDataChannels.pop();
 		if (!next)
 			break;
 
@@ -1128,7 +1128,7 @@ void PeerConnection::triggerPendingDataChannels() {
 
 void PeerConnection::triggerPendingTracks() {
 	while (trackCallback) {
-		auto next = mPendingTracks.tryPop();
+		auto next = mPendingTracks.pop();
 		if (!next)
 			break;
 

+ 1 - 1
src/impl/processor.cpp

@@ -21,7 +21,7 @@ void Processor::join() {
 
 void Processor::schedule() {
 	std::unique_lock lock(mMutex);
-	if (auto next = mTasks.tryPop()) {
+	if (auto next = mTasks.pop()) {
 		ThreadPool::Instance().enqueue(std::move(*next));
 	} else {
 		// No more tasks

+ 12 - 44
src/impl/queue.hpp

@@ -34,19 +34,14 @@ public:
 	size_t amount() const; // amount
 	void push(T element);
 	optional<T> pop();
-	optional<T> tryPop();
 	optional<T> peek();
 	optional<T> exchange(T element);
-	bool wait(const optional<std::chrono::milliseconds> &duration = nullopt);
 
 private:
-	void pushImpl(T element);
-	optional<T> popImpl();
-
 	const size_t mLimit;
 	size_t mAmount;
 	std::queue<T> mQueue;
-	std::condition_variable mPopCondition, mPushCondition;
+	std::condition_variable mPushCondition;
 	amount_function mAmountFunction;
 	bool mStopping = false;
 
@@ -66,7 +61,6 @@ template <typename T> Queue<T>::~Queue() { stop(); }
 template <typename T> void Queue<T>::stop() {
 	std::lock_guard lock(mMutex);
 	mStopping = true;
-	mPopCondition.notify_all();
 	mPushCondition.notify_all();
 }
 
@@ -98,18 +92,22 @@ template <typename T> size_t Queue<T>::amount() const {
 template <typename T> void Queue<T>::push(T element) {
 	std::unique_lock lock(mMutex);
 	mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
-	pushImpl(std::move(element));
+	if (mStopping)
+		return;
+
+	mAmount += mAmountFunction(element);
+	mQueue.emplace(std::move(element));
 }
 
 template <typename T> optional<T> Queue<T>::pop() {
 	std::unique_lock lock(mMutex);
-	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
-	return popImpl();
-}
+	if (mQueue.empty())
+		return nullopt;
 
-template <typename T> optional<T> Queue<T>::tryPop() {
-	std::unique_lock lock(mMutex);
-	return popImpl();
+	mAmount -= mAmountFunction(mQueue.front());
+	optional<T> element{std::move(mQueue.front())};
+	mQueue.pop();
+	return element;
 }
 
 template <typename T> optional<T> Queue<T>::peek() {
@@ -126,36 +124,6 @@ template <typename T> optional<T> Queue<T>::exchange(T element) {
 	return std::make_optional(std::move(element));
 }
 
-template <typename T> bool Queue<T>::wait(const optional<std::chrono::milliseconds> &duration) {
-	std::unique_lock lock(mMutex);
-	if (duration) {
-		return mPopCondition.wait_for(lock, *duration,
-		                              [this]() { return !mQueue.empty() || mStopping; });
-	} else {
-		mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
-		return true;
-	}
-}
-
-template <typename T> void Queue<T>::pushImpl(T element) {
-	if (mStopping)
-		return;
-
-	mAmount += mAmountFunction(element);
-	mQueue.emplace(std::move(element));
-	mPopCondition.notify_one();
-}
-
-template <typename T> optional<T> Queue<T>::popImpl() {
-	if (mQueue.empty())
-		return nullopt;
-
-	mAmount -= mAmountFunction(mQueue.front());
-	optional<T> element{std::move(mQueue.front())};
-	mQueue.pop();
-	return element;
-}
-
 } // namespace rtc::impl
 
 #endif

+ 110 - 90
src/impl/tlstransport.cpp

@@ -8,6 +8,7 @@
 
 #include "tlstransport.hpp"
 #include "tcptransport.hpp"
+#include "threadpool.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
 
@@ -20,6 +21,16 @@ using namespace std::chrono;
 
 namespace rtc::impl {
 
+void TlsTransport::enqueueRecv() {
+	if (mPendingRecvCount > 0)
+		return;
+
+	if (auto shared_this = weak_from_this().lock()) {
+		++mPendingRecvCount;
+		ThreadPool::Instance().enqueue(&TlsTransport::doRecv, std::move(shared_this));
+	}
+}
+
 #if USE_GNUTLS
 
 namespace {
@@ -54,7 +65,8 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host
 
 	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
 
-	gnutls::check(gnutls_init(&mSession, mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER));
+	unsigned int flags = GNUTLS_NONBLOCK | (mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER);
+	gnutls::check(gnutls_init(&mSession, flags));
 
 	try {
 		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
@@ -89,22 +101,17 @@ TlsTransport::~TlsTransport() {
 }
 
 void TlsTransport::start() {
-	if (mStarted.exchange(true))
-		return;
-
-	PLOG_DEBUG << "Starting TLS recv thread";
+	PLOG_DEBUG << "Starting TLS transport";
 	registerIncoming();
-	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+	changeState(State::Connecting);
+	enqueueRecv(); // to initiate the handshake
 }
 
 void TlsTransport::stop() {
-	if (!mStarted.exchange(false))
-		return;
-
-	PLOG_DEBUG << "Stopping TLS recv thread";
+	PLOG_DEBUG << "Stopping TLS transport";
 	unregisterIncoming();
 	mIncomingQueue.stop();
-	mRecvThread.join();
+	enqueueRecv();
 }
 
 bool TlsTransport::send(message_ptr message) {
@@ -130,11 +137,13 @@ bool TlsTransport::send(message_ptr message) {
 void TlsTransport::incoming(message_ptr message) {
 	if (!message) {
 		mIncomingQueue.stop();
+		enqueueRecv();
 		return;
 	}
 
 	PLOG_VERBOSE << "Incoming size=" << message->size();
 	mIncomingQueue.push(message);
+	enqueueRecv();
 }
 
 bool TlsTransport::outgoing(message_ptr message) {
@@ -147,52 +156,52 @@ void TlsTransport::postHandshake() {
 	// Dummy
 }
 
-void TlsTransport::runRecvLoop() {
+void TlsTransport::doRecv() {
+	std::lock_guard lock(mRecvMutex);
+	--mPendingRecvCount;
+
 	const size_t bufferSize = 4096;
 	char buffer[bufferSize];
 
-	// Handshake loop
 	try {
-		changeState(State::Connecting);
+		// Handle handshake if connecting
+		if (state() == State::Connecting) {
+			int ret;
+			do {
+				ret = gnutls_handshake(mSession);
 
-		int ret;
-		do {
-			ret = gnutls_handshake(mSession);
-		} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
-		         !gnutls::check(ret, "TLS handshake failed"));
+				if (ret == GNUTLS_E_AGAIN)
+					return;
 
-	} catch (const std::exception &e) {
-		PLOG_ERROR << "TLS handshake: " << e.what();
-		changeState(State::Failed);
-		return;
-	}
+			} while (!gnutls::check(ret, "TLS handshake failed")); // Re-call on non-fatal error
 
-	// Receive loop
-	try {
-		PLOG_INFO << "TLS handshake finished";
-		changeState(State::Connected);
-		postHandshake();
+			PLOG_INFO << "TLS handshake finished";
+			changeState(State::Connected);
+			postHandshake();
+		}
 
-		while (true) {
-			ssize_t ret;
-			do {
-				ret = gnutls_record_recv(mSession, buffer, bufferSize);
-			} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
+		if (state() == State::Connected) {
+			while (true) {
+				ssize_t ret = gnutls_record_recv(mSession, buffer, bufferSize);
 
-			// Consider premature termination as remote closing
-			if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
-				PLOG_DEBUG << "TLS connection terminated";
-				break;
-			}
+				if (ret == GNUTLS_E_AGAIN)
+					return;
 
-			if (gnutls::check(ret)) {
-				if (ret == 0) {
-					// Closed
-					PLOG_DEBUG << "TLS connection cleanly closed";
+				// Consider premature termination as remote closing
+				if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
+					PLOG_DEBUG << "TLS connection terminated";
 					break;
 				}
-				auto *b = reinterpret_cast<byte *>(buffer);
-				recv(make_message(b, b + ret));
+
+				if (gnutls::check(ret)) {
+					if (ret == 0) {
+						// Closed
+						PLOG_DEBUG << "TLS connection cleanly closed";
+						break;
+					}
+					auto *b = reinterpret_cast<byte *>(buffer);
+					recv(make_message(b, b + ret));
+				}
 			}
 		}
 	} catch (const std::exception &e) {
@@ -250,6 +259,9 @@ ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_
 			position += len;
 			gnutls_transport_set_errno(t->mSession, 0);
 			return len;
+		} else if (t->mIncomingQueue.running()) {
+			gnutls_transport_set_errno(t->mSession, EAGAIN);
+			return -1;
 		} else {
 			// Closed
 			gnutls_transport_set_errno(t->mSession, 0);
@@ -263,18 +275,16 @@ ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_
 	}
 }
 
-int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
+int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* ms */) {
 	TlsTransport *t = static_cast<TlsTransport *>(ptr);
 	try {
 		message_ptr &message = t->mIncomingMessage;
 		size_t &position = t->mIncomingMessagePosition;
 
-		if(message && position < message->size())
+		if (message && position < message->size())
 			return 1;
 
-		bool isReadable = t->mIncomingQueue.wait(
-		    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
-		return isReadable ? 1 : 0;
+		return !t->mIncomingQueue.empty() ? 1 : 0;
 
 	} catch (const std::exception &e) {
 		PLOG_WARNING << e.what();
@@ -375,23 +385,22 @@ TlsTransport::~TlsTransport() {
 }
 
 void TlsTransport::start() {
-	if (mStarted.exchange(true))
-		return;
-
-	PLOG_DEBUG << "Starting TLS recv thread";
+	PLOG_DEBUG << "Starting TLS transport";
 	registerIncoming();
-	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+	changeState(State::Connecting);
+
+	// Initiate the handshake
+	int ret = SSL_do_handshake(mSsl);
+	openssl::check(mSsl, ret, "Handshake initiation failed");
+
+	flushOutput();
 }
 
 void TlsTransport::stop() {
-	if (!mStarted.exchange(false))
-		return;
-
-	PLOG_DEBUG << "Stopping TLS recv thread";
+	PLOG_DEBUG << "Stopping TLS transport";
 	unregisterIncoming();
 	mIncomingQueue.stop();
-	mRecvThread.join();
-	SSL_shutdown(mSsl);
+	enqueueRecv();
 }
 
 bool TlsTransport::send(message_ptr message) {
@@ -407,23 +416,19 @@ bool TlsTransport::send(message_ptr message) {
 	if (!openssl::check(mSsl, ret))
 		throw std::runtime_error("TLS send failed");
 
-	const size_t bufferSize = 4096;
-	byte buffer[bufferSize];
-	bool result = true;
-	while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
-		result = outgoing(make_message(buffer, buffer + ret));
-
-	return result;
+	return flushOutput();
 }
 
 void TlsTransport::incoming(message_ptr message) {
 	if (!message) {
 		mIncomingQueue.stop();
+		enqueueRecv();
 		return;
 	}
 
 	PLOG_VERBOSE << "Incoming size=" << message->size();
 	mIncomingQueue.push(message);
+	enqueueRecv();
 }
 
 bool TlsTransport::outgoing(message_ptr message) { return Transport::outgoing(std::move(message)); }
@@ -432,24 +437,36 @@ void TlsTransport::postHandshake() {
 	// Dummy
 }
 
-void TlsTransport::runRecvLoop() {
-	const size_t bufferSize = 4096;
-	byte buffer[bufferSize];
+void TlsTransport::doRecv() {
+	std::lock_guard lock(mRecvMutex);
+	--mPendingRecvCount;
+
+	if (state() != State::Connecting && state() != State::Connected)
+		return;
 
 	try {
-		changeState(State::Connecting);
+		const size_t bufferSize = 4096;
+		byte buffer[bufferSize];
+
+		// Process incoming messages
+		while (mIncomingQueue.running()) {
+			auto next = mIncomingQueue.pop();
+			if (!next)
+				return;
+
+			message_ptr message = std::move(*next);
+			if (message->size() > 0)
+				BIO_write(mInBio, message->data(), int(message->size())); // Input
+			else
+				recv(message); // Pass zero-sized messages through
 
-		int ret;
-		while (true) {
 			if (state() == State::Connecting) {
-				// Initiate or continue the handshake
-				ret = SSL_do_handshake(mSsl);
+				// Continue the handshake
+				int ret = SSL_do_handshake(mSsl);
 				if (!openssl::check(mSsl, ret, "Handshake failed"))
 					break;
 
-				// Output
-				while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
-					outgoing(make_message(buffer, buffer + ret));
+				flushOutput();
 
 				if (SSL_is_init_finished(mSsl)) {
 					PLOG_INFO << "TLS handshake finished";
@@ -459,29 +476,21 @@ void TlsTransport::runRecvLoop() {
 			}
 
 			if (state() == State::Connected) {
-				// Input
+				int ret;
 				while ((ret = SSL_read(mSsl, buffer, bufferSize)) > 0)
 					recv(make_message(buffer, buffer + ret));
 
 				if (!openssl::check(mSsl, ret))
 					break;
 			}
-
-			auto next = mIncomingQueue.pop();
-			if (!next)
-				break;
-
-			message_ptr message = std::move(*next);
-			if (message->size() > 0)
-				BIO_write(mInBio, message->data(), int(message->size())); // Input
-			else
-				recv(message); // Pass zero-sized messages through
 		}
 
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TLS recv: " << e.what();
 	}
 
+	SSL_shutdown(mSsl);
+
 	if (state() == State::Connected) {
 		PLOG_INFO << "TLS closed";
 		recv(nullptr);
@@ -490,6 +499,17 @@ void TlsTransport::runRecvLoop() {
 	}
 }
 
+bool TlsTransport::flushOutput() {
+	const size_t bufferSize = 4096;
+	byte buffer[bufferSize];
+	int ret;
+	bool result = true;
+	while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
+		result = outgoing(make_message(buffer, buffer + ret));
+
+	return result;
+}
+
 void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
 	TlsTransport *t =
 	    static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));

+ 14 - 4
src/impl/tlstransport.hpp

@@ -24,7 +24,7 @@ namespace rtc::impl {
 
 class TcpTransport;
 
-class TlsTransport : public Transport {
+class TlsTransport : public Transport, public std::enable_shared_from_this<TlsTransport> {
 public:
 	static void Init();
 	static void Cleanup();
@@ -44,14 +44,15 @@ protected:
 	virtual bool outgoing(message_ptr message) override;
 	virtual void postHandshake();
 
-	void runRecvLoop();
+	void enqueueRecv();
+	void doRecv();
 
 	const optional<string> mHost;
 	const bool mIsClient;
 
 	Queue<message_ptr> mIncomingQueue;
-	std::thread mRecvThread;
-	std::atomic<bool> mStarted = false;
+	std::atomic<int> mPendingRecvCount = 0;
+	std::mutex mRecvMutex;
 
 #if USE_GNUTLS
 	gnutls_session_t mSession;
@@ -68,10 +69,19 @@ protected:
 	SSL *mSsl;
 	BIO *mInBio, *mOutBio;
 
+	bool flushOutput();
+
+	static BIO_METHOD *BioMethods;
 	static int TransportExIndex;
+	static std::mutex GlobalMutex;
 
 	static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
 	static void InfoCallback(const SSL *ssl, int where, int ret);
+
+	static int BioMethodNew(BIO *bio);
+	static int BioMethodFree(BIO *bio);
+	static int BioMethodWrite(BIO *bio, const char *in, int inl);
+	static long BioMethodCtrl(BIO *bio, int cmd, long num, void *ptr);
 #endif
 };
 

+ 1 - 1
src/impl/track.cpp

@@ -63,7 +63,7 @@ void Track::close() {
 }
 
 optional<message_variant> Track::receive() {
-	if (auto next = mRecvQueue.tryPop()) {
+	if (auto next = mRecvQueue.pop()) {
 		message_ptr message = *next;
 		if (message->type == Message::Control)
 			return to_variant(**next); // The same message may be frowarded into multiple Tracks

+ 2 - 2
src/impl/websocket.cpp

@@ -133,7 +133,7 @@ bool WebSocket::isClosed() const { return state == State::Closed; }
 size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 
 optional<message_variant> WebSocket::receive() {
-	while (auto next = mRecvQueue.tryPop()) {
+	while (auto next = mRecvQueue.pop()) {
 		message_ptr message = *next;
 		if (message->type != Message::Control)
 			return to_variant(std::move(*message));
@@ -147,7 +147,7 @@ optional<message_variant> WebSocket::peek() {
 		if (message->type != Message::Control)
 			return to_variant(std::move(*message));
 
-		mRecvQueue.tryPop();
+		mRecvQueue.pop();
 	}
 	return nullopt;
 }