Browse Source

Fixed forwardMessage() to open a new channel if an old one exists

Paul-Louis Ageneau 3 years ago
parent
commit
1a2ae09bd6
3 changed files with 40 additions and 25 deletions
  1. 9 1
      src/impl/datachannel.cpp
  2. 2 0
      src/impl/datachannel.hpp
  3. 29 24
      src/impl/peerconnection.cpp

+ 9 - 1
src/impl/datachannel.cpp

@@ -78,6 +78,14 @@ struct CloseMessage {
 LogCounter COUNTER_USERNEG_OPEN_MESSAGE(
 LogCounter COUNTER_USERNEG_OPEN_MESSAGE(
     plog::warning, "Number of open messages for a user-negotiated DataChannel received");
     plog::warning, "Number of open messages for a user-negotiated DataChannel received");
 
 
+bool DataChannel::IsOpenMessage(message_ptr message) {
+	if (message->type != Message::Control)
+		return false;
+
+	auto raw = reinterpret_cast<const uint8_t *>(message->data());
+	return !message->empty() && raw[0] == MESSAGE_OPEN;
+}
+
 DataChannel::DataChannel(weak_ptr<PeerConnection> pc, uint16_t stream, string label,
 DataChannel::DataChannel(weak_ptr<PeerConnection> pc, uint16_t stream, string label,
                          string protocol, Reliability reliability)
                          string protocol, Reliability reliability)
     : mPeerConnection(pc), mStream(stream), mLabel(std::move(label)),
     : mPeerConnection(pc), mStream(stream), mLabel(std::move(label)),
@@ -217,7 +225,7 @@ bool DataChannel::outgoing(message_ptr message) {
 }
 }
 
 
 void DataChannel::incoming(message_ptr message) {
 void DataChannel::incoming(message_ptr message) {
-	if (!message)
+	if (!message || mIsClosed)
 		return;
 		return;
 
 
 	switch (message->type) {
 	switch (message->type) {

+ 2 - 0
src/impl/datachannel.hpp

@@ -35,6 +35,8 @@ namespace rtc::impl {
 struct PeerConnection;
 struct PeerConnection;
 
 
 struct DataChannel : Channel, std::enable_shared_from_this<DataChannel> {
 struct DataChannel : Channel, std::enable_shared_from_this<DataChannel> {
+	static bool IsOpenMessage(message_ptr message);
+
 	DataChannel(weak_ptr<PeerConnection> pc, uint16_t stream, string label, string protocol,
 	DataChannel(weak_ptr<PeerConnection> pc, uint16_t stream, string label, string protocol,
 	            Reliability reliability);
 	            Reliability reliability);
 	virtual ~DataChannel();
 	virtual ~DataChannel();

+ 29 - 24
src/impl/peerconnection.cpp

@@ -419,42 +419,47 @@ void PeerConnection::forwardMessage(message_ptr message) {
 		return;
 		return;
 	}
 	}
 
 
-	uint16_t stream = uint16_t(message->stream);
+	const uint16_t stream = uint16_t(message->stream);
 	auto channel = findDataChannel(stream);
 	auto channel = findDataChannel(stream);
-	if (!channel) {
+
+	if (DataChannel::IsOpenMessage(message)) {
 		auto iceTransport = getIceTransport();
 		auto iceTransport = getIceTransport();
 		auto sctpTransport = getSctpTransport();
 		auto sctpTransport = getSctpTransport();
 		if (!iceTransport || !sctpTransport)
 		if (!iceTransport || !sctpTransport)
 			return;
 			return;
 
 
-		// See https://tools.ietf.org/html/rfc8832
-		const byte dataChannelOpenMessage{0x03};
-		uint16_t remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0;
-		if (message->type == Message::Control) {
-			if (message->size() == 0 || *message->data() != dataChannelOpenMessage)
-				return; // ignore
-
-			if (stream % 2 != remoteParity) {
-				// The odd/even rule is violated, close the DataChannel
-				sctpTransport->closeStream(message->stream);
-				return;
-			}
+		const uint16_t remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0;
+		if (stream % 2 != remoteParity) {
+			// The odd/even rule is violated, close the DataChannel
+			PLOG_WARNING << "Got open message violating the odd/even rule on stream " << stream;
+			sctpTransport->closeStream(message->stream);
+			return;
+		}
 
 
-			channel =
-			    std::make_shared<NegotiatedDataChannel>(weak_from_this(), sctpTransport, stream);
-			channel->openCallback = weak_bind(&PeerConnection::triggerDataChannel, this,
-			                                  weak_ptr<DataChannel>{channel});
+		if (channel && channel->isOpen()) {
+			PLOG_WARNING << "Got open message on stream " << stream
+			             << " for an already open DataChannel, closing it first";
+			channel->close();
+		}
 
 
-			std::unique_lock lock(mDataChannelsMutex); // we are going to emplace
-			mDataChannels.emplace(stream, channel);
+		channel = std::make_shared<NegotiatedDataChannel>(weak_from_this(), sctpTransport, stream);
+		channel->openCallback =
+		    weak_bind(&PeerConnection::triggerDataChannel, this, weak_ptr<DataChannel>{channel});
 
 
-		} else {
-			// Invalid, close the DataChannel
+		std::unique_lock lock(mDataChannelsMutex); // we are going to emplace
+		mDataChannels.emplace(stream, channel);
+	}
+
+	if (!channel) {
+		// Invalid, close the DataChannel
+		PLOG_WARNING << "Got unexpected message on stream " << stream;
+		if (auto sctpTransport = getSctpTransport())
 			sctpTransport->closeStream(message->stream);
 			sctpTransport->closeStream(message->stream);
-			return;
-		}
+
+		return;
 	}
 	}
 
 
+	// Forward the message
 	channel->incoming(message);
 	channel->incoming(message);
 }
 }