Browse Source

Properly handle exceptions from threadpool tasks

Paul-Louis Ageneau 4 years ago
parent
commit
bbec827fef
4 changed files with 38 additions and 29 deletions
  1. 20 3
      include/rtc/include.hpp
  2. 8 19
      src/processor.hpp
  3. 1 5
      src/threadpool.cpp
  4. 9 2
      src/threadpool.hpp

+ 20 - 3
include/rtc/include.hpp

@@ -62,7 +62,7 @@ using std::uint8_t;
 const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length
 const size_t MAX_NUMERICSERV_LEN = 6;  // Max port string representation length
 
-const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
+const uint16_t DEFAULT_SCTP_PORT = 5000;          // SCTP port to use by default
 const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536;    // Remote max message size if not specified in SDP
 const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size
 
@@ -72,7 +72,7 @@ const int THREADPOOL_SIZE = 4; // Number of threads in the global thread pool
 
 // overloaded helper
 template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
-template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;
+template <class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
 
 // weak_ptr bind helper
 template <typename F, typename T, typename... Args> auto weak_bind(F &&f, T *t, Args &&... _args) {
@@ -85,6 +85,23 @@ template <typename F, typename T, typename... Args> auto weak_bind(F &&f, T *t,
 	};
 }
 
+// scope_guard helper
+class scope_guard {
+public:
+	scope_guard(std::function<void()> func) : function(std::move(func)) {}
+	scope_guard(scope_guard &&other) = delete;
+	scope_guard(const scope_guard &) = delete;
+	void operator=(const scope_guard &) = delete;
+
+	~scope_guard() {
+		if (function)
+			function();
+	}
+
+private:
+	std::function<void()> function;
+};
+
 template <typename... P> class synchronized_callback {
 public:
 	synchronized_callback() = default;
@@ -127,6 +144,6 @@ private:
 	std::function<void(P...)> callback;
 	mutable std::recursive_mutex mutex;
 };
-}
+} // namespace rtc
 
 #endif

+ 8 - 19
src/processor.hpp

@@ -45,7 +45,7 @@ public:
 	void join();
 
 	template <class F, class... Args>
-	auto enqueue(F &&f, Args &&... args) -> invoke_future_t<F, Args...>;
+	void enqueue(F &&f, Args &&... args);
 
 protected:
 	void schedule();
@@ -60,31 +60,20 @@ protected:
 	std::condition_variable mCondition;
 };
 
-template <class F, class... Args>
-auto Processor::enqueue(F &&f, Args &&... args) -> invoke_future_t<F, Args...> {
+template <class F, class... Args> void Processor::enqueue(F &&f, Args &&... args) {
 	std::unique_lock lock(mMutex);
-	using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;
-	auto task = std::make_shared<std::packaged_task<R()>>(
-	    std::bind(std::forward<F>(f), std::forward<Args>(args)...));
-	std::future<R> result = task->get_future();
-
-	auto bundle = [this, task = std::move(task)]() {
-		try {
-			(*task)();
-		} catch (const std::exception &e) {
-			PLOG_WARNING << "Unhandled exception in task: " << e.what();
-		}
-		schedule(); // chain the next task
+	auto bound = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
+	auto task = [this, bound = std::move(bound)]() mutable {
+		scope_guard guard(std::bind(&Processor::schedule, this)); // chain the next task
+		return bound();
 	};
 
 	if (!mPending) {
-		ThreadPool::Instance().enqueue(std::move(bundle));
+		ThreadPool::Instance().enqueue(std::move(task));
 		mPending = true;
 	} else {
-		mTasks.emplace(std::move(bundle));
+		mTasks.emplace(std::move(task));
 	}
-
-	return result;
 }
 
 } // namespace rtc

+ 1 - 5
src/threadpool.cpp

@@ -58,11 +58,7 @@ void ThreadPool::run() {
 
 bool ThreadPool::runOne() {
 	if (auto task = dequeue()) {
-		try {
-			task();
-		} catch (const std::exception &e) {
-			PLOG_WARNING << "Unhandled exception in task: " << e.what();
-		}
+		task();
 		return true;
 	}
 	return false;

+ 9 - 2
src/threadpool.hpp

@@ -73,8 +73,15 @@ template <class F, class... Args>
 auto ThreadPool::enqueue(F &&f, Args &&... args) -> invoke_future_t<F, Args...> {
 	std::unique_lock lock(mMutex);
 	using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;
-	auto task = std::make_shared<std::packaged_task<R()>>(
-	    std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+	auto bound = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
+	auto task = std::make_shared<std::packaged_task<R()>>([bound = std::move(bound)]() mutable {
+        try {
+            return bound();
+        } catch (const std::exception &e) {
+            PLOG_WARNING << e.what();
+            throw;
+        }
+    });
 	std::future<R> result = task->get_future();
 
 	mTasks.emplace([task = std::move(task), token = Init::Token()]() { return (*task)(); });