Browse Source

Merge pull request #113 from paullouisageneau/fix-capi-userptr

Fix potential callback call with null user pointer in C API
Paul-Louis Ageneau 5 years ago
parent
commit
0a6b263bc3
1 changed files with 46 additions and 22 deletions
  1. 46 22
      src/rtc.cpp

+ 46 - 22
src/rtc.cpp

@@ -55,18 +55,15 @@ std::unordered_map<int, void *> userPointerMap;
 std::mutex mutex;
 int lastId = 0;
 
-void *getUserPointer(int id) {
+std::optional<void *> getUserPointer(int id) {
 	std::lock_guard lock(mutex);
 	auto it = userPointerMap.find(id);
-	return it != userPointerMap.end() ? it->second : nullptr;
+	return it != userPointerMap.end() ? std::make_optional(it->second) : nullopt;
 }
 
 void setUserPointer(int i, void *ptr) {
 	std::lock_guard lock(mutex);
-	if (ptr)
-		userPointerMap[i] = ptr;
-	else
-		userPointerMap.erase(i);
+	userPointerMap[i] = ptr;
 }
 
 shared_ptr<PeerConnection> getPeerConnection(int id) {
@@ -89,6 +86,7 @@ int emplacePeerConnection(shared_ptr<PeerConnection> ptr) {
 	std::lock_guard lock(mutex);
 	int pc = ++lastId;
 	peerConnectionMap.emplace(std::make_pair(pc, ptr));
+	userPointerMap.emplace(std::make_pair(pc, nullptr));
 	return pc;
 }
 
@@ -96,6 +94,7 @@ int emplaceDataChannel(shared_ptr<DataChannel> ptr) {
 	std::lock_guard lock(mutex);
 	int dc = ++lastId;
 	dataChannelMap.emplace(std::make_pair(dc, ptr));
+	userPointerMap.emplace(std::make_pair(dc, nullptr));
 	return dc;
 }
 
@@ -126,6 +125,7 @@ int emplaceWebSocket(shared_ptr<WebSocket> ptr) {
 	std::lock_guard lock(mutex);
 	int ws = ++lastId;
 	webSocketMap.emplace(std::make_pair(ws, ptr));
+	userPointerMap.emplace(std::make_pair(ws, nullptr));
 	return ws;
 }
 
@@ -244,7 +244,8 @@ int rtcCreateDataChannel(int pc, const char *label) {
 	return WRAP({
 		auto peerConnection = getPeerConnection(pc);
 		int dc = emplaceDataChannel(peerConnection->createDataChannel(string(label)));
-		rtcSetUserPointer(dc, getUserPointer(pc));
+		if (auto ptr = getUserPointer(pc))
+			rtcSetUserPointer(dc, *ptr);
 		return dc;
 	});
 }
@@ -293,9 +294,10 @@ int rtcSetDataChannelCallback(int pc, rtcDataChannelCallbackFunc cb) {
 		if (cb)
 			peerConnection->onDataChannel([pc, cb](std::shared_ptr<DataChannel> dataChannel) {
 				int dc = emplaceDataChannel(dataChannel);
-				void *ptr = getUserPointer(pc);
-				rtcSetUserPointer(dc, ptr);
-				cb(dc, ptr);
+				if (auto ptr = getUserPointer(pc)) {
+					rtcSetUserPointer(dc, *ptr);
+					cb(dc, *ptr);
+				}
 			});
 		else
 			peerConnection->onDataChannel(nullptr);
@@ -307,7 +309,8 @@ int rtcSetLocalDescriptionCallback(int pc, rtcDescriptionCallbackFunc cb) {
 		auto peerConnection = getPeerConnection(pc);
 		if (cb)
 			peerConnection->onLocalDescription([pc, cb](const Description &desc) {
-				cb(string(desc).c_str(), desc.typeString().c_str(), getUserPointer(pc));
+				if (auto ptr = getUserPointer(pc))
+					cb(string(desc).c_str(), desc.typeString().c_str(), *ptr);
 			});
 		else
 			peerConnection->onLocalDescription(nullptr);
@@ -319,7 +322,8 @@ int rtcSetLocalCandidateCallback(int pc, rtcCandidateCallbackFunc cb) {
 		auto peerConnection = getPeerConnection(pc);
 		if (cb)
 			peerConnection->onLocalCandidate([pc, cb](const Candidate &cand) {
-				cb(cand.candidate().c_str(), cand.mid().c_str(), getUserPointer(pc));
+				if (auto ptr = getUserPointer(pc))
+					cb(cand.candidate().c_str(), cand.mid().c_str(), *ptr);
 			});
 		else
 			peerConnection->onLocalCandidate(nullptr);
@@ -331,7 +335,8 @@ int rtcSetStateChangeCallback(int pc, rtcStateChangeCallbackFunc cb) {
 		auto peerConnection = getPeerConnection(pc);
 		if (cb)
 			peerConnection->onStateChange([pc, cb](PeerConnection::State state) {
-				cb(static_cast<rtcState>(state), getUserPointer(pc));
+				if (auto ptr = getUserPointer(pc))
+					cb(static_cast<rtcState>(state), *ptr);
 			});
 		else
 			peerConnection->onStateChange(nullptr);
@@ -343,7 +348,8 @@ int rtcSetGatheringStateChangeCallback(int pc, rtcGatheringStateCallbackFunc cb)
 		auto peerConnection = getPeerConnection(pc);
 		if (cb)
 			peerConnection->onGatheringStateChange([pc, cb](PeerConnection::GatheringState state) {
-				cb(static_cast<rtcGatheringState>(state), getUserPointer(pc));
+				if (auto ptr = getUserPointer(pc))
+					cb(static_cast<rtcGatheringState>(state), *ptr);
 			});
 		else
 			peerConnection->onGatheringStateChange(nullptr);
@@ -435,7 +441,10 @@ int rtcSetOpenCallback(int id, rtcOpenCallbackFunc cb) {
 	return WRAP({
 		auto channel = getChannel(id);
 		if (cb)
-			channel->onOpen([id, cb]() { cb(getUserPointer(id)); });
+			channel->onOpen([id, cb]() {
+				if (auto ptr = getUserPointer(id))
+					cb(*ptr);
+			});
 		else
 			channel->onOpen(nullptr);
 	});
@@ -445,7 +454,10 @@ int rtcSetClosedCallback(int id, rtcClosedCallbackFunc cb) {
 	return WRAP({
 		auto channel = getChannel(id);
 		if (cb)
-			channel->onClosed([id, cb]() { cb(getUserPointer(id)); });
+			channel->onClosed([id, cb]() {
+				if (auto ptr = getUserPointer(id))
+					cb(*ptr);
+			});
 		else
 			channel->onClosed(nullptr);
 	});
@@ -455,8 +467,10 @@ int rtcSetErrorCallback(int id, rtcErrorCallbackFunc cb) {
 	return WRAP({
 		auto channel = getChannel(id);
 		if (cb)
-			channel->onError(
-			    [id, cb](const string &error) { cb(error.c_str(), getUserPointer(id)); });
+			channel->onError([id, cb](const string &error) {
+				if (auto ptr = getUserPointer(id))
+					cb(error.c_str(), *ptr);
+			});
 		else
 			channel->onError(nullptr);
 	});
@@ -468,9 +482,13 @@ int rtcSetMessageCallback(int id, rtcMessageCallbackFunc cb) {
 		if (cb)
 			channel->onMessage(
 			    [id, cb](const binary &b) {
-				    cb(reinterpret_cast<const char *>(b.data()), int(b.size()), getUserPointer(id));
+				    if (auto ptr = getUserPointer(id))
+					    cb(reinterpret_cast<const char *>(b.data()), int(b.size()), *ptr);
 			    },
-			    [id, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(id)); });
+			    [id, cb](const string &s) {
+				    if (auto ptr = getUserPointer(id))
+					    cb(s.c_str(), -(s.size() + 1), *ptr);
+			    });
 		else
 			channel->onMessage(nullptr);
 	});
@@ -514,7 +532,10 @@ int rtcSetBufferedAmountLowCallback(int id, rtcBufferedAmountLowCallbackFunc cb)
 	return WRAP({
 		auto channel = getChannel(id);
 		if (cb)
-			channel->onBufferedAmountLow([id, cb]() { cb(getUserPointer(id)); });
+			channel->onBufferedAmountLow([id, cb]() {
+				if (auto ptr = getUserPointer(id))
+					cb(*ptr);
+			});
 		else
 			channel->onBufferedAmountLow(nullptr);
 	});
@@ -528,7 +549,10 @@ int rtcSetAvailableCallback(int id, rtcAvailableCallbackFunc cb) {
 	return WRAP({
 		auto channel = getChannel(id);
 		if (cb)
-			channel->onOpen([id, cb]() { cb(getUserPointer(id)); });
+			channel->onOpen([id, cb]() {
+				if (auto ptr = getUserPointer(id))
+					cb(*ptr);
+			});
 		else
 			channel->onOpen(nullptr);
 	});