Browse Source

Added Processor and finished ThreadPool integration

Paul-Louis Ageneau 5 years ago
parent
commit
aecc2b8fda

+ 1 - 0
CMakeLists.txt

@@ -44,6 +44,7 @@ set(LIBDATACHANNEL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/tls.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/threadpool.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/processor.cpp
 )
 
 set(LIBDATACHANNEL_WEBSOCKET_SOURCES

+ 2 - 0
include/rtc/include.hpp

@@ -64,6 +64,8 @@ const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
 const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536;    // Remote max message size if not specified in SDP
 const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size
 
+const int THREADPOOL_SIZE = 4; // Number of threads in the global thread pool
+
 // overloaded helper
 template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
 template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;

+ 3 - 2
include/rtc/peerconnection.hpp

@@ -41,6 +41,7 @@
 namespace rtc {
 
 class Certificate;
+class Processor;
 class IceTransport;
 class DtlsTransport;
 class SctpTransport;
@@ -139,10 +140,10 @@ private:
 
 	void outgoingMedia(message_ptr message);
 
+	const init_token mInitToken = Init::Token();
 	const Configuration mConfig;
 	const future_certificate_ptr mCertificate;
-
-	init_token mInitToken = Init::Token();
+	const std::unique_ptr<Processor> mProcessor;
 
 	std::optional<Description> mLocalDescription, mRemoteDescription;
 	mutable std::mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;

+ 2 - 14
src/certificate.cpp

@@ -17,6 +17,7 @@
  */
 
 #include "certificate.hpp"
+#include "threadpool.hpp"
 
 #include <cassert>
 #include <chrono>
@@ -230,19 +231,6 @@ namespace rtc {
 
 namespace {
 
-// Helper function roughly equivalent to std::async with policy std::launch::async
-// since std::async might be unreliable on some platforms (e.g. Mingw32 on Windows)
-template <class F, class... Args>
-std::future<std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>>
-thread_call(F &&f, Args &&... args) {
-	using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;
-	std::packaged_task<R()> task(std::bind(f, std::forward<Args>(args)...));
-	std::future<R> future = task.get_future();
-	std::thread t(std::move(task));
-	t.detach();
-	return future;
-}
-
 static std::unordered_map<string, future_certificate_ptr> CertificateCache;
 static std::mutex CertificateCacheMutex;
 
@@ -254,7 +242,7 @@ future_certificate_ptr make_certificate(string commonName) {
 	if (auto it = CertificateCache.find(commonName); it != CertificateCache.end())
 		return it->second;
 
-	auto future = thread_call(make_certificate_impl, commonName);
+	auto future = ThreadPool::Instance().enqueue(make_certificate_impl, commonName);
 	auto shared = future.share();
 	CertificateCache.emplace(std::move(commonName), shared);
 	return shared;

+ 5 - 0
src/init.cpp

@@ -21,6 +21,7 @@
 #include "certificate.hpp"
 #include "dtlstransport.hpp"
 #include "sctptransport.hpp"
+#include "threadpool.hpp"
 #include "tls.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
@@ -72,6 +73,8 @@ Init::Init() {
 		throw std::runtime_error("WSAStartup failed, error=" + std::to_string(WSAGetLastError()));
 #endif
 
+	ThreadPool::Instance().spawn(THREADPOOL_SIZE);
+
 #if USE_GNUTLS
 		// Nothing to do
 #else
@@ -98,6 +101,8 @@ Init::~Init() {
 	DtlsSrtpTransport::Cleanup();
 #endif
 
+	ThreadPool::Instance().join();
+
 #ifdef _WIN32
 	WSACleanup();
 #endif

+ 26 - 23
src/peerconnection.cpp

@@ -18,9 +18,12 @@
 
 #include "peerconnection.hpp"
 #include "certificate.hpp"
+#include "include.hpp"
+#include "processor.hpp"
+#include "threadpool.hpp"
+
 #include "dtlstransport.hpp"
 #include "icetransport.hpp"
-#include "include.hpp"
 #include "sctptransport.hpp"
 
 #if RTC_ENABLE_MEDIA
@@ -39,8 +42,8 @@ using std::weak_ptr;
 PeerConnection::PeerConnection() : PeerConnection(Configuration()) {}
 
 PeerConnection::PeerConnection(const Configuration &config)
-    : mConfig(config), mCertificate(make_certificate()), mState(State::New),
-      mGatheringState(GatheringState::New) {
+    : mConfig(config), mCertificate(make_certificate()), mProcessor(std::make_unique<Processor>()),
+      mState(State::New), mGatheringState(GatheringState::New) {
 	PLOG_VERBOSE << "Creating PeerConnection";
 
 	if (config.portRangeEnd && config.portRangeBegin > config.portRangeEnd)
@@ -145,6 +148,7 @@ void PeerConnection::addRemoteCandidate(Candidate candidate) {
 		iceTransport->addRemoteCandidate(candidate);
 	} else {
 		// OK, we might need a lookup, do it asynchronously
+		// We don't use the thread pool because we have no control on the timeout
 		weak_ptr<IceTransport> weakIceTransport{iceTransport};
 		std::thread t([weakIceTransport, candidate]() mutable {
 			if (candidate.resolve(Candidate::ResolveMode::Lookup))
@@ -445,21 +449,18 @@ void PeerConnection::closeTransports() {
 	auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr));
 	auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
 	auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr));
-	if (sctp || dtls || ice) {
-		std::thread t([sctp, dtls, ice, token = mInitToken]() mutable {
-			if (sctp)
-				sctp->stop();
-			if (dtls)
-				dtls->stop();
-			if (ice)
-				ice->stop();
-
-			sctp.reset();
-			dtls.reset();
-			ice.reset();
-		});
-		t.detach();
-	}
+	ThreadPool::Instance().enqueue([sctp, dtls, ice, token = mInitToken]() mutable {
+		if (sctp)
+			sctp->stop();
+		if (dtls)
+			dtls->stop();
+		if (ice)
+			ice->stop();
+
+		sctp.reset();
+		dtls.reset();
+		ice.reset();
+	});
 }
 
 void PeerConnection::endLocalCandidates() {
@@ -613,7 +614,7 @@ void PeerConnection::processLocalDescription(Description description) {
 		mLocalDescription->setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE);
 	}
 
-	mLocalDescriptionCallback(*mLocalDescription);
+	mProcessor->enqueue([this]() { mLocalDescriptionCallback(*mLocalDescription); });
 }
 
 void PeerConnection::processLocalCandidate(Candidate candidate) {
@@ -623,7 +624,8 @@ void PeerConnection::processLocalCandidate(Candidate candidate) {
 
 	mLocalDescription->addCandidate(candidate);
 
-	mLocalCandidateCallback(candidate);
+	mProcessor->enqueue(
+	    [this, candidate = std::move(candidate)]() { mLocalCandidateCallback(candidate); });
 }
 
 void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {
@@ -631,7 +633,8 @@ void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {
 	if (!dataChannel)
 		return;
 
-	mDataChannelCallback(dataChannel);
+	mProcessor->enqueue(
+	    [this, dataChannel = std::move(dataChannel)]() { mDataChannelCallback(dataChannel); });
 }
 
 bool PeerConnection::changeState(State state) {
@@ -645,13 +648,13 @@ bool PeerConnection::changeState(State state) {
 
 	} while (!mState.compare_exchange_weak(current, state));
 
-	mStateChangeCallback(state);
+	mProcessor->enqueue([this, state]() { mStateChangeCallback(state); });
 	return true;
 }
 
 bool PeerConnection::changeGatheringState(GatheringState state) {
 	if (mGatheringState.exchange(state) != state)
-		mGatheringStateChangeCallback(state);
+		mProcessor->enqueue([this, state] { mGatheringStateChangeCallback(state); });
 	return true;
 }
 

+ 44 - 0
src/processor.cpp

@@ -0,0 +1,44 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "processor.hpp"
+
+namespace rtc {
+
+Processor::~Processor() { join(); }
+
+void Processor::join() {
+	std::unique_lock lock(mMutex);
+	mCondition.wait(lock, [this]() { return !mPending && mTasks.empty(); });
+}
+
+void Processor::schedule() {
+	std::unique_lock lock(mMutex);
+	if (mTasks.empty()) {
+		// No more tasks
+		mPending = false;
+		mCondition.notify_all();
+		return;
+	}
+
+	ThreadPool::Instance().enqueue(std::move(mTasks.front()));
+	mTasks.pop();
+}
+
+} // namespace rtc
+

+ 87 - 0
src/processor.hpp

@@ -0,0 +1,87 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_PROCESSOR_H
+#define RTC_PROCESSOR_H
+
+#include "include.hpp"
+#include "threadpool.hpp"
+
+#include <condition_variable>
+#include <future>
+#include <memory>
+#include <mutex>
+#include <queue>
+
+namespace rtc {
+
+class Processor final {
+public:
+	Processor() = default;
+	~Processor();
+
+	Processor(const Processor &) = delete;
+	Processor &operator=(const Processor &) = delete;
+	Processor(Processor &&) = delete;
+	Processor &operator=(Processor &&) = delete;
+
+	void join();
+
+	template <class F, class... Args>
+	auto enqueue(F &&f, Args &&... args) -> invoke_future_t<F, Args...>;
+
+protected:
+	void schedule();
+
+	std::queue<std::function<void()>> mTasks;
+	bool mPending = false;
+
+	mutable std::mutex mMutex;
+	std::condition_variable mCondition;
+};
+
+template <class F, class... Args>
+auto Processor::enqueue(F &&f, Args &&... args) -> invoke_future_t<F, Args...> {
+	std::unique_lock lock(mMutex);
+	using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;
+	auto task = std::make_shared<std::packaged_task<R()>>(
+	    std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+	std::future<R> result = task->get_future();
+
+	auto bundle = [this, task = std::move(task)]() {
+		try {
+			(*task)();
+		} catch (const std::exception &e) {
+			PLOG_WARNING << "Unhandled exception in task: " << e.what();
+		}
+		schedule();
+	};
+
+	if (!mPending) {
+		ThreadPool::Instance().enqueue(std::move(bundle));
+		mPending = true;
+	} else {
+		mTasks.emplace(std::move(bundle));
+	}
+
+	return result;
+}
+
+} // namespace rtc
+
+#endif

+ 13 - 0
src/tcptransport.cpp

@@ -107,6 +107,7 @@ bool TcpTransport::stop() {
 }
 
 bool TcpTransport::send(message_ptr message) {
+	std::unique_lock lock(mSockMutex);
 	if (state() != State::Connected)
 		return false;
 
@@ -126,6 +127,7 @@ void TcpTransport::incoming(message_ptr message) {
 }
 
 bool TcpTransport::outgoing(message_ptr message) {
+	// mSockMutex must be locked
 	// If nothing is pending, try to send directly
 	// It's safe because if the queue is empty, the thread is not sending
 	if (mSendQueue.empty() && trySendMessage(message))
@@ -174,6 +176,7 @@ void TcpTransport::connect(const string &hostname, const string &service) {
 }
 
 void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) {
+	std::unique_lock lock(mSockMutex);
 	try {
 		char node[MAX_NUMERICNODE_LEN];
 		char serv[MAX_NUMERICSERV_LEN];
@@ -248,15 +251,18 @@ void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) {
 }
 
 void TcpTransport::close() {
+	std::unique_lock lock(mSockMutex);
 	if (mSock != INVALID_SOCKET) {
 		PLOG_DEBUG << "Closing TCP socket";
 		::closesocket(mSock);
 		mSock = INVALID_SOCKET;
 	}
 	changeState(State::Disconnected);
+	interruptSelect();
 }
 
 bool TcpTransport::trySendQueue() {
+	// mSockMutex must be locked
 	while (auto next = mSendQueue.peek()) {
 		auto message = *next;
 		if (!trySendMessage(message)) {
@@ -269,6 +275,7 @@ bool TcpTransport::trySendQueue() {
 }
 
 bool TcpTransport::trySendMessage(message_ptr &message) {
+	// mSockMutex must be locked
 	auto data = reinterpret_cast<const char *>(message->data());
 	auto size = message->size();
 	while (size) {
@@ -314,13 +321,19 @@ void TcpTransport::runLoop() {
 		changeState(State::Connected);
 
 		while (true) {
+			std::unique_lock lock(mSockMutex);
+			if (mSock == INVALID_SOCKET)
+				break;
+
 			fd_set readfds, writefds;
 			int n = prepareSelect(readfds, writefds);
 
 			struct timeval tv;
 			tv.tv_sec = 10;
 			tv.tv_usec = 0;
+			lock.unlock();
 			int ret = ::select(n, &readfds, &writefds, NULL, &tv);
+			lock.lock();
 			if (ret < 0) {
 				throw std::runtime_error("Failed to wait on socket");
 			} else if (ret == 0) {

+ 1 - 0
src/tcptransport.hpp

@@ -78,6 +78,7 @@ private:
 	string mHostname, mService;
 
 	socket_t mSock = INVALID_SOCKET;
+	std::mutex mSockMutex;
 	std::thread mThread;
 	SelectInterrupter mInterrupter;
 	Queue<message_ptr> mSendQueue;

+ 12 - 20
src/threadpool.cpp

@@ -25,40 +25,32 @@ ThreadPool &ThreadPool::Instance() {
 	return instance;
 }
 
-ThreadPool::ThreadPool(int count) { spawn(count); }
-
 ThreadPool::~ThreadPool() { join(); }
 
 int ThreadPool::count() const {
-	std::unique_lock lock(mMutex);
+	std::unique_lock lock(mWorkersMutex);
 	return mWorkers.size();
 }
 
 void ThreadPool::spawn(int count) {
-	std::unique_lock lock(mMutex);
+	std::unique_lock lock(mWorkersMutex);
+	mJoining = false;
 	while (count-- > 0)
 		mWorkers.emplace_back(std::bind(&ThreadPool::run, this));
 }
 
 void ThreadPool::join() {
-	try {
-		std::unique_lock lock(mMutex);
-		mJoining = true;
-		mCondition.notify_all();
+	std::unique_lock lock(mWorkersMutex);
+	mJoining = true;
+	mCondition.notify_all();
 
-		auto workers = std::move(mWorkers);
-		mWorkers.clear();
+	for (auto &w : mWorkers)
+		if (w.get_id() == std::this_thread::get_id())
+			w.detach(); // detach ourselves
+		else
+			w.join(); // join others
 
-		lock.unlock();
-		for (auto &w : workers)
-			w.join();
-
-	} catch (...) {
-		mJoining = false;
-		throw;
-	}
-
-	mJoining = false;
+	mWorkers.clear();
 }
 
 void ThreadPool::run() {

+ 10 - 6
src/threadpool.hpp

@@ -24,6 +24,7 @@
 #include <condition_variable>
 #include <functional>
 #include <future>
+#include <memory>
 #include <mutex>
 #include <queue>
 #include <stdexcept>
@@ -38,6 +39,7 @@ using invoke_future_t = std::future<std::invoke_result_t<std::decay_t<F>, std::d
 class ThreadPool final {
 public:
 	static ThreadPool &Instance();
+
 	ThreadPool(const ThreadPool &) = delete;
 	ThreadPool &operator=(const ThreadPool &) = delete;
 	ThreadPool(ThreadPool &&) = delete;
@@ -53,26 +55,28 @@ public:
 	auto enqueue(F &&f, Args &&... args) -> invoke_future_t<F, Args...>;
 
 protected:
-	explicit ThreadPool(int count = 0);
+	ThreadPool() = default;
 	~ThreadPool();
 
 	std::function<void()> dequeue(); // returns null function if joining
 
 	std::vector<std::thread> mWorkers;
 	std::queue<std::function<void()>> mTasks;
+	std::atomic<bool> mJoining = false;
 
-	std::mutex mMutex;
+	mutable std::mutex mMutex, mWorkersMutex;
 	std::condition_variable mCondition;
-	bool mJoining = false;
 };
 
 template <class F, class... Args>
 auto ThreadPool::enqueue(F &&f, Args &&... args) -> invoke_future_t<F, Args...> {
 	std::unique_lock lock(mMutex);
 	using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;
-	auto task = std::packaged_task<R()>(std::bind(std::forward<F>(f), std::forward<Args>(args)...));
-	std::future<R> result = task.get_future();
-	mTasks.emplace([task = std::move(task)]() { task(); });
+	auto task = std::make_shared<std::packaged_task<R()>>(
+	    std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+	std::future<R> result = task->get_future();
+
+	mTasks.emplace([task = std::move(task)]() { return (*task)(); });
 	mCondition.notify_one();
 	return result;
 }

+ 14 - 16
src/websocket.cpp

@@ -18,8 +18,9 @@
 
 #if RTC_ENABLE_WEBSOCKET
 
-#include "include.hpp"
 #include "websocket.hpp"
+#include "include.hpp"
+#include "threadpool.hpp"
 
 #include "tcptransport.hpp"
 #include "tlstransport.hpp"
@@ -301,21 +302,18 @@ void WebSocket::closeTransports() {
 	auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
 	auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
 	auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
-	if (ws || tls || tcp) {
-		std::thread t([ws, tls, tcp, token = mInitToken]() mutable {
-			if (ws)
-				ws->stop();
-			if (tls)
-				tls->stop();
-			if (tcp)
-				tcp->stop();
-
-			ws.reset();
-			tls.reset();
-			tcp.reset();
-		});
-		t.detach();
-	}
+	ThreadPool::Instance().enqueue([ws, tls, tcp, token = Init::Token()]() mutable {
+		if (ws)
+			ws->stop();
+		if (tls)
+			tls->stop();
+		if (tcp)
+			tcp->stop();
+
+		ws.reset();
+		tls.reset();
+		tcp.reset();
+	});
 }
 
 } // namespace rtc