Browse Source

Ignore messages to registered but destroyed DataChannel

Paul-Louis Ageneau 2 years ago
parent
commit
5419df3204
2 changed files with 32 additions and 22 deletions
  1. 30 21
      src/impl/peerconnection.cpp
  2. 2 1
      src/impl/peerconnection.hpp

+ 30 - 21
src/impl/peerconnection.cpp

@@ -429,21 +429,22 @@ void PeerConnection::forwardMessage(message_ptr message) {
 		return;
 
 	const uint16_t stream = uint16_t(message->stream);
-	auto channel = findDataChannel(stream);
+	auto [channel, found] = findDataChannel(stream);
 
 	if (DataChannel::IsOpenMessage(message)) {
 		const uint16_t remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0;
 		if (stream % 2 != remoteParity) {
-			// The odd/even rule is violated, close the DataChannel
+			// The odd/even rule is violated, the receiver must close the DataChannel
 			PLOG_WARNING << "Got open message violating the odd/even rule on stream " << stream;
 			sctpTransport->closeStream(message->stream);
 			return;
 		}
 
-		if (channel && channel->isOpen()) {
-			PLOG_WARNING << "Got open message on stream " << stream
-			             << " for an already open DataChannel, closing it first";
-			channel->close();
+		if (found) {
+			// The stream is already used, the receiver must close the DataChannel
+			PLOG_WARNING << "Got open message on already used stream " << stream;
+			sctpTransport->closeStream(message->stream);
+			return;
 		}
 
 		channel = std::make_shared<IncomingDataChannel>(weak_from_this(), sctpTransport);
@@ -454,14 +455,7 @@ void PeerConnection::forwardMessage(message_ptr message) {
 		std::unique_lock lock(mDataChannelsMutex); // we are going to emplace
 		mDataChannels.emplace(stream, channel);
 	}
-
-	if (message->type == Message::Reset) {
-		// Incoming stream is reset, unregister it
-		std::unique_lock lock(mDataChannelsMutex); // we are going to erase
-		mDataChannels.erase(stream);
-	}
-
-	if (!channel) {
+	else if (!found) {
 		if (message->type == Message::Reset)
 			return; // ignore
 
@@ -471,8 +465,18 @@ void PeerConnection::forwardMessage(message_ptr message) {
 		return;
 	}
 
-	// Forward the message
-	channel->incoming(message);
+	if (message->type == Message::Reset) {
+		// Incoming stream is reset, unregister it
+		removeDataChannel(stream);
+	}
+
+	if (channel) {
+		// Forward the message
+		channel->incoming(message);
+	} else {
+		// DataChannel was destroyed, ignore
+		PLOG_DEBUG << "Ignored message on stream " << stream << ", DataChannel is destroyed";
+	}
 }
 
 void PeerConnection::forwardMedia(message_ptr message) {
@@ -558,7 +562,8 @@ void PeerConnection::forwardMedia(message_ptr message) {
 }
 
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
-	if (auto channel = findDataChannel(stream))
+	[[maybe_unused]] auto [channel, found] = findDataChannel(stream);
+	if (channel)
 		channel->triggerBufferedAmount(amount);
 }
 
@@ -599,13 +604,17 @@ shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(string label, DataCha
 	return channel;
 }
 
-shared_ptr<DataChannel> PeerConnection::findDataChannel(uint16_t stream) {
+std::pair<shared_ptr<DataChannel>, bool> PeerConnection::findDataChannel(uint16_t stream) {
 	std::shared_lock lock(mDataChannelsMutex); // read-only
 	if (auto it = mDataChannels.find(stream); it != mDataChannels.end())
-		if (auto channel = it->second.lock())
-			return channel;
+		return std::make_pair(it->second.lock(), true);
+	else
+		return std::make_pair(nullptr, false);
+}
 
-	return nullptr;
+bool PeerConnection::removeDataChannel(uint16_t stream) {
+		std::unique_lock lock(mDataChannelsMutex); // we are going to erase
+		return mDataChannels.erase(stream) != 0;
 }
 
 uint16_t PeerConnection::maxDataChannelStream() const {

+ 2 - 1
src/impl/peerconnection.hpp

@@ -58,7 +58,8 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 
 	shared_ptr<DataChannel> emplaceDataChannel(string label, DataChannelInit init);
-	shared_ptr<DataChannel> findDataChannel(uint16_t stream);
+	std::pair<shared_ptr<DataChannel>, bool> findDataChannel(uint16_t stream);
+	bool removeDataChannel(uint16_t stream);
 	uint16_t maxDataChannelStream() const;
 	void assignDataChannels();
 	void iterateDataChannels(std::function<void(shared_ptr<DataChannel> channel)> func);