瀏覽代碼

Added callback wrapper in Channel

Paul-Louis Ageneau 5 年之前
父節點
當前提交
abdf61e841
共有 2 個文件被更改,包括 35 次插入46 次删除
  1. 25 12
      include/rtc/channel.hpp
  2. 10 34
      src/channel.cpp

+ 25 - 12
include/rtc/channel.hpp

@@ -54,18 +54,31 @@ protected:
 	virtual void triggerSent();
 
 private:
-	std::function<void()> mOpenCallback;
-	std::function<void()> mClosedCallback;
-	std::function<void(const string &)> mErrorCallback;
-	std::function<void(const std::variant<binary, string> &)> mMessageCallback;
-	std::function<void()> mAvailableCallback;
-	std::function<void()> mSentCallback;
-	std::mutex mCallbackMutex;
-
-	template <typename T> T getCallback(const T &callback) {
-		std::lock_guard<std::mutex> lock(mCallbackMutex);
-		return callback;
-	}
+	template <typename... P> class synchronized_callback {
+	public:
+		synchronized_callback &operator=(std::function<void(P...)> func) {
+			std::lock_guard<std::recursive_mutex> lock(mutex);
+			callback = func;
+			return *this;
+		}
+
+		void operator()(P... args) {
+			std::lock_guard<std::recursive_mutex> lock(mutex);
+			if (callback)
+				callback(args...);
+		}
+
+	private:
+		std::function<void(P...)> callback;
+		std::recursive_mutex mutex;
+	};
+
+	synchronized_callback<> mOpenCallback;
+	synchronized_callback<> mClosedCallback;
+	synchronized_callback<const string &> mErrorCallback;
+	synchronized_callback<const std::variant<binary, string> &> mMessageCallback;
+	synchronized_callback<> mAvailableCallback;
+	synchronized_callback<> mSentCallback;
 };
 
 } // namespace rtc

+ 10 - 34
src/channel.cpp

@@ -23,30 +23,23 @@ namespace {}
 namespace rtc {
 
 void Channel::onOpen(std::function<void()> callback) {
-	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mOpenCallback = callback;
 }
 
 void Channel::onClosed(std::function<void()> callback) {
-	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mClosedCallback = callback;
 }
 
 void Channel::onError(std::function<void(const string &error)> callback) {
-	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mErrorCallback = callback;
 }
 
 void Channel::onMessage(std::function<void(const std::variant<binary, string> &data)> callback) {
-	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mMessageCallback = callback;
 
 	// Pass pending messages
-	while (auto message = receive()) {
-		// The callback might be changed from itself
-		if (auto callback = getCallback(mMessageCallback))
-			callback(*message);
-	}
+	while (auto message = receive())
+		mMessageCallback(*message);
 }
 
 void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
@@ -57,49 +50,32 @@ void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
 }
 
 void Channel::onAvailable(std::function<void()> callback) {
-	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mAvailableCallback = callback;
 }
 
 void Channel::onSent(std::function<void()> callback) {
-	std::lock_guard<std::mutex> lock(mCallbackMutex);
 	mSentCallback = callback;
 }
 
-void Channel::triggerOpen() {
-	if (auto callback = getCallback(mOpenCallback))
-		callback();
-}
+void Channel::triggerOpen() { mOpenCallback(); }
 
-void Channel::triggerClosed() {
-	if (auto callback = getCallback(mClosedCallback))
-		callback();
-}
+void Channel::triggerClosed() { mClosedCallback(); }
 
-void Channel::triggerError(const string &error) {
-	if (auto callback = getCallback(mErrorCallback))
-		callback(error);
-}
+void Channel::triggerError(const string &error) { mErrorCallback(error); }
 
 void Channel::triggerAvailable(size_t available) {
-	if (available == 1) {
-		if (auto callback = getCallback(mAvailableCallback))
-			callback();
-	}
+	if (available == 1)
+		mAvailableCallback();
+
 	while (available--) {
 		auto message = receive();
 		if (!message)
 			break;
-		// The callback might be changed from itself
-		if (auto callback = getCallback(mMessageCallback))
-			callback(*message);
+		mMessageCallback(*message);
 	}
 }
 
-void Channel::triggerSent() {
-	if (auto callback = getCallback(mSentCallback))
-		callback();
-}
+void Channel::triggerSent() { mSentCallback(); }
 
 } // namespace rtc