Explorar el Código

Enhanced API to differentiate binary and string messages

Paul-Louis Ageneau hace 6 años
padre
commit
61fb38305a
Se han modificado 5 ficheros con 40 adiciones y 39 borrados
  1. 2 0
      include/rtc/channel.hpp
  2. 3 0
      include/rtc/include.hpp
  3. 1 1
      include/rtc/rtc.h
  4. 7 0
      src/channel.cpp
  5. 27 38
      src/rtc.cpp

+ 2 - 0
include/rtc/channel.hpp

@@ -38,6 +38,8 @@ public:
 	void onClosed(std::function<void()> callback);
 	void onError(std::function<void(const string &error)> callback);
 	void onMessage(std::function<void(const std::variant<binary, string> &data)> callback);
+	void onMessage(std::function<void(const binary &data)> binaryCallback,
+	               std::function<void(const string &data)> stringCallback);
 
 protected:
 	virtual void triggerOpen(void);

+ 3 - 0
include/rtc/include.hpp

@@ -41,6 +41,9 @@ const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length
 const size_t MAX_NUMERICSERV_LEN = 6;  // Max port string representation length
 
 const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
+
+template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
+template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;
 }
 
 #endif

+ 1 - 1
include/rtc/rtc.h

@@ -36,7 +36,7 @@ int rtcGetDataChannelLabel(int dc, char *data, int size);
 void rtcSetOpenCallback(int dc, void (*openCallback)(void *));
 void rtcSetErrorCallback(int dc, void (*errorCallback)(const char *, void *));
 void rtcSetMessageCallback(int dc, void (*messageCallback)(const char *, int, void *));
-int rtcSendMessage(int dc, const char *buffer, int size);
+int rtcSendMessage(int dc, const char *data, int size);
 void rtcSetUserPointer(int i, void *ptr);
 
 #ifdef __cplusplus

+ 7 - 0
src/channel.cpp

@@ -32,6 +32,13 @@ void Channel::onMessage(std::function<void(const std::variant<binary, string> &d
 	mMessageCallback = callback;
 }
 
+void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
+                        std::function<void(const string &data)> stringCallback) {
+	onMessage([binaryCallback, stringCallback](const std::variant<binary, string> &data) {
+		std::visit(overloaded{binaryCallback, stringCallback}, data);
+	});
+}
+
 void Channel::triggerOpen(void) {
 	if (mOpenCallback)
 		mOpenCallback();

+ 27 - 38
src/rtc.cpp

@@ -34,6 +34,11 @@ std::unordered_map<int, shared_ptr<DataChannel>> dataChannelMap;
 std::unordered_map<int, void *> userPointerMap;
 int lastId = 0;
 
+void *getUserPointer(int id) {
+	auto it = userPointerMap.find(id);
+	return it != userPointerMap.end() ? it->second : nullptr;
+}
+
 } // namespace
 
 int rtcCreatePeerConnection(const char **iceServers, int iceServersCount) {
@@ -68,10 +73,7 @@ void rtcSetDataChannelCallback(int pc, void (*dataChannelCallback)(int, void *))
 	it->second->onDataChannel([pc, dataChannelCallback](std::shared_ptr<DataChannel> dataChannel) {
 		int dc = ++lastId;
 		dataChannelMap.emplace(std::make_pair(dc, dataChannel));
-		void *userPointer = nullptr;
-		if (auto jt = userPointerMap.find(pc); jt != userPointerMap.end())
-			userPointer = jt->second;
-		dataChannelCallback(dc, userPointer);
+		dataChannelCallback(dc, getUserPointer(pc));
 	});
 }
 
@@ -82,11 +84,8 @@ void rtcSetLocalDescriptionCallback(int pc, void (*descriptionCallback)(const ch
 		return;
 
 	it->second->onLocalDescription([pc, descriptionCallback](const Description &description) {
-		void *userPointer = nullptr;
-		if (auto jt = userPointerMap.find(pc); jt != userPointerMap.end())
-			userPointer = jt->second;
 		descriptionCallback(string(description).c_str(), description.typeString().c_str(),
-		                    userPointer);
+		                    getUserPointer(pc));
 	});
 }
 
@@ -98,14 +97,11 @@ void rtcSetLocalCandidateCallback(int pc,
 
 	it->second->onLocalCandidate(
 	    [pc, candidateCallback](const std::optional<Candidate> &candidate) {
-		    void *userPointer = nullptr;
-		    if (auto jt = userPointerMap.find(pc); jt != userPointerMap.end())
-			    userPointer = jt->second;
 		    if (candidate) {
 			    candidateCallback(string(*candidate).c_str(), candidate->mid().c_str(),
-			                      userPointer);
+			                      getUserPointer(pc));
 		    } else {
-			    candidateCallback(nullptr, nullptr, userPointer);
+			    candidateCallback(nullptr, nullptr, getUserPointer(pc));
 		    }
 	    });
 }
@@ -146,12 +142,7 @@ void rtcSetOpenCallback(int dc, void (*openCallback)(void *)) {
 	if (it == dataChannelMap.end())
 		return;
 
-	it->second->onOpen([dc, openCallback]() {
-		void *userPointer = nullptr;
-		if (auto jt = userPointerMap.find(dc); jt != userPointerMap.end())
-			userPointer = jt->second;
-		openCallback(userPointer);
-	});
+	it->second->onOpen([dc, openCallback]() { openCallback(getUserPointer(dc)); });
 }
 
 void rtcSetErrorCallback(int dc, void (*errorCallback)(const char *, void *)) {
@@ -160,10 +151,7 @@ void rtcSetErrorCallback(int dc, void (*errorCallback)(const char *, void *)) {
 		return;
 
 	it->second->onError([dc, errorCallback](const string &error) {
-		void *userPointer = nullptr;
-		if (auto jt = userPointerMap.find(dc); jt != userPointerMap.end())
-			userPointer = jt->second;
-		errorCallback(error.c_str(), userPointer);
+		errorCallback(error.c_str(), getUserPointer(dc));
 	});
 }
 
@@ -172,18 +160,13 @@ void rtcSetMessageCallback(int dc, void (*messageCallback)(const char *, int, vo
 	if (it == dataChannelMap.end())
 		return;
 
-	it->second->onMessage([dc, messageCallback](const std::variant<binary, string> &message) {
-		void *userPointer = nullptr;
-		if (auto jt = userPointerMap.find(dc); jt != userPointerMap.end())
-			userPointer = jt->second;
-		std::visit(
-		    [messageCallback, userPointer](const auto &v) {
-			    auto data = reinterpret_cast<const char *>(v.data());
-			    int size = v.size();
-			    messageCallback(data, size, userPointer);
-		    },
-		    message);
-	});
+	it->second->onMessage(
+	    [dc, messageCallback](const binary &b) {
+		    messageCallback(reinterpret_cast<const char *>(b.data()), b.size(), getUserPointer(dc));
+	    },
+	    [dc, messageCallback](const string &s) {
+		    messageCallback(s.c_str(), -1, getUserPointer(dc));
+	    });
 }
 
 int rtcSendMessage(int dc, const char *data, int size) {
@@ -191,9 +174,15 @@ int rtcSendMessage(int dc, const char *data, int size) {
 	if (it == dataChannelMap.end())
 		return 0;
 
-	auto b = reinterpret_cast<const byte *>(data);
-	it->second->send(b, size);
-	return size;
+	if (size >= 0) {
+		auto b = reinterpret_cast<const byte *>(data);
+		it->second->send(b, size);
+		return size;
+	} else {
+		string s(data);
+		it->second->send(s);
+		return s.size();
+	}
 }
 
 void rtcSetUserPointer(int i, void *ptr) {