Browse Source

Added shared mutex to protect data channels map

Paul-Louis Ageneau 5 years ago
parent
commit
f3b3208367
2 changed files with 44 additions and 28 deletions
  1. 7 0
      include/rtc/peerconnection.hpp
  2. 37 28
      src/peerconnection.cpp

+ 7 - 0
include/rtc/peerconnection.hpp

@@ -32,6 +32,7 @@
 #include <functional>
 #include <functional>
 #include <list>
 #include <list>
 #include <mutex>
 #include <mutex>
+#include <shared_mutex>
 #include <thread>
 #include <thread>
 #include <unordered_map>
 #include <unordered_map>
 
 
@@ -95,6 +96,11 @@ private:
 	bool checkFingerprint(const std::string &fingerprint) const;
 	bool checkFingerprint(const std::string &fingerprint) const;
 	void forwardMessage(message_ptr message);
 	void forwardMessage(message_ptr message);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
+
+	std::shared_ptr<DataChannel> emplaceDataChannel(Description::Role role, const string &label,
+	                                                const string &protocol,
+	                                                const Reliability &reliability);
+	std::shared_ptr<DataChannel> findDataChannel(uint16_t stream);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void openDataChannels();
 	void openDataChannels();
 	void closeDataChannels();
 	void closeDataChannels();
@@ -118,6 +124,7 @@ private:
 	std::recursive_mutex mInitMutex;
 	std::recursive_mutex mInitMutex;
 
 
 	std::unordered_map<unsigned int, std::weak_ptr<DataChannel>> mDataChannels;
 	std::unordered_map<unsigned int, std::weak_ptr<DataChannel>> mDataChannels;
+	std::shared_mutex mDataChannelsMutex;
 
 
 	std::atomic<State> mState;
 	std::atomic<State> mState;
 	std::atomic<GatheringState> mGatheringState;
 	std::atomic<GatheringState> mGatheringState;

+ 37 - 28
src/peerconnection.cpp

@@ -62,7 +62,6 @@ PeerConnection::~PeerConnection() {
 void PeerConnection::close() {
 void PeerConnection::close() {
 	// Close DataChannels
 	// Close DataChannels
 	closeDataChannels();
 	closeDataChannels();
-	mDataChannels.clear();
 
 
 	// Close Transports
 	// Close Transports
 	for (int i = 0; i < 2; ++i) { // Make sure a transport wasn't spawn behind our back
 	for (int i = 0; i < 2; ++i) { // Make sure a transport wasn't spawn behind our back
@@ -115,12 +114,16 @@ void PeerConnection::setRemoteDescription(Description description) {
 		if (!sctpTransport && iceTransport->role() == Description::Role::Active) {
 		if (!sctpTransport && iceTransport->role() == Description::Role::Active) {
 			// Since we assumed passive role during DataChannel creation, we need to shift the
 			// Since we assumed passive role during DataChannel creation, we need to shift the
 			// stream numbers by one to shift them from odd to even.
 			// stream numbers by one to shift them from odd to even.
+			std::unique_lock lock(mDataChannelsMutex);
 			decltype(mDataChannels) newDataChannels;
 			decltype(mDataChannels) newDataChannels;
-			iterateDataChannels([&](shared_ptr<DataChannel> channel) {
+			auto it = mDataChannels.begin();
+			while (it != mDataChannels.end()) {
+				auto channel = it->second.lock();
 				if (channel->stream() % 2 == 1)
 				if (channel->stream() % 2 == 1)
 					channel->mStream -= 1;
 					channel->mStream -= 1;
 				newDataChannels.emplace(channel->stream(), channel);
 				newDataChannels.emplace(channel->stream(), channel);
-			});
+				++it;
+			}
 			std::swap(mDataChannels, newDataChannels);
 			std::swap(mDataChannels, newDataChannels);
 		}
 		}
 	}
 	}
@@ -172,19 +175,7 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
 	auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
 
 
-	// The active side must use streams with even identifiers, whereas the passive side must use
-	// streams with odd identifiers.
-	// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6
-	unsigned int stream = (role == Description::Role::Active) ? 0 : 1;
-	while (mDataChannels.find(stream) != mDataChannels.end()) {
-		stream += 2;
-		if (stream >= 65535)
-			throw std::runtime_error("Too many DataChannels");
-	}
-
-	auto channel =
-	    std::make_shared<DataChannel>(shared_from_this(), stream, label, protocol, reliability);
-	mDataChannels.insert(std::make_pair(stream, channel));
+	auto channel = emplaceDataChannel(role, label, protocol, reliability);
 
 
 	if (!iceTransport) {
 	if (!iceTransport) {
 		// RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of
 		// RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of
@@ -367,14 +358,7 @@ void PeerConnection::forwardMessage(message_ptr message) {
 		return;
 		return;
 	}
 	}
 
 
-	shared_ptr<DataChannel> channel;
-	if (auto it = mDataChannels.find(message->stream); it != mDataChannels.end()) {
-		channel = it->second.lock();
-		if (!channel || channel->isClosed()) {
-			mDataChannels.erase(it);
-			channel = nullptr;
-		}
-	}
+	auto channel = findDataChannel(message->stream);
 
 
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto sctpTransport = std::atomic_load(&mSctpTransport);
 	auto sctpTransport = std::atomic_load(&mSctpTransport);
@@ -402,21 +386,46 @@ void PeerConnection::forwardMessage(message_ptr message) {
 }
 }
 
 
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
+	if (auto channel = findDataChannel(stream))
+		channel->triggerBufferedAmount(amount);
+}
+
+shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(Description::Role role,
+                                                           const string &label,
+                                                           const string &protocol,
+                                                           const Reliability &reliability) {
+	// The active side must use streams with even identifiers, whereas the passive side must use
+	// streams with odd identifiers.
+	// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6
+	std::unique_lock lock(mDataChannelsMutex);
+	unsigned int stream = (role == Description::Role::Active) ? 0 : 1;
+	while (mDataChannels.find(stream) != mDataChannels.end()) {
+		stream += 2;
+		if (stream >= 65535)
+			throw std::runtime_error("Too many DataChannels");
+	}
+	auto channel =
+	    std::make_shared<DataChannel>(shared_from_this(), stream, label, protocol, reliability);
+	mDataChannels.emplace(std::make_pair(stream, channel));
+	return channel;
+}
+
+shared_ptr<DataChannel> PeerConnection::findDataChannel(uint16_t stream) {
+	std::shared_lock lock(mDataChannelsMutex);
 	shared_ptr<DataChannel> channel;
 	shared_ptr<DataChannel> channel;
 	if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) {
 	if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) {
 		channel = it->second.lock();
 		channel = it->second.lock();
 		if (!channel || channel->isClosed()) {
 		if (!channel || channel->isClosed()) {
 			mDataChannels.erase(it);
 			mDataChannels.erase(it);
-			channel = nullptr;
+			channel.reset();
 		}
 		}
 	}
 	}
-
-	if (channel)
-		channel->triggerBufferedAmount(amount);
+	return channel;
 }
 }
 
 
 void PeerConnection::iterateDataChannels(
 void PeerConnection::iterateDataChannels(
     std::function<void(shared_ptr<DataChannel> channel)> func) {
     std::function<void(shared_ptr<DataChannel> channel)> func) {
+	std::shared_lock lock(mDataChannelsMutex);
 	auto it = mDataChannels.begin();
 	auto it = mDataChannels.begin();
 	while (it != mDataChannels.end()) {
 	while (it != mDataChannels.end()) {
 		auto channel = it->second.lock();
 		auto channel = it->second.lock();