Browse Source

Added sanity checks for DataChannel creation

Paul-Louis Ageneau 6 years ago
parent
commit
8cb69c8269
2 changed files with 25 additions and 7 deletions
  1. 10 4
      src/datachannel.cpp
  2. 15 3
      src/peerconnection.cpp

+ 10 - 4
src/datachannel.cpp

@@ -24,6 +24,9 @@ namespace rtc {
 
 using std::shared_ptr;
 
+// Messages for the DataChannel establishment protocol
+// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09
+
 enum MessageType : uint8_t {
 	MESSAGE_OPEN_REQUEST = 0x00,
 	MESSAGE_OPEN_RESPONSE = 0x01,
@@ -75,7 +78,7 @@ void DataChannel::close() {
 }
 
 void DataChannel::send(const std::variant<binary, string> &data) {
-	if (!mSctpTransport)
+	if (mIsClosed || !mSctpTransport)
 		return;
 
 	std::visit(
@@ -83,16 +86,19 @@ void DataChannel::send(const std::variant<binary, string> &data) {
 		    using T = std::decay_t<decltype(d)>;
 		    constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
 		    auto *b = reinterpret_cast<const byte *>(d.data());
-		    mSctpTransport->send(make_message(b, b + d.size(), type, mStream, mReliability));
+		    // Before the ACK has been received on a DataChannel, all messages must be sent ordered
+		    auto reliability = mIsOpen ? mReliability : nullptr;
+		    mSctpTransport->send(make_message(b, b + d.size(), type, mStream, reliability));
 	    },
 	    data);
 }
 
 void DataChannel::send(const byte *data, size_t size) {
-	if (!mSctpTransport)
+	if (mIsClosed || !mSctpTransport)
 		return;
 
-	mSctpTransport->send(make_message(data, data + size, Message::Binary, mStream));
+	auto reliability = mIsOpen ? mReliability : nullptr;
+	mSctpTransport->send(make_message(data, data + size, Message::Binary, mStream, reliability));
 }
 
 unsigned int DataChannel::stream() const { return mStream; }

+ 15 - 3
src/peerconnection.cpp

@@ -129,13 +129,25 @@ bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
 }
 
 void PeerConnection::forwardMessage(message_ptr message) {
+	if (!mIceTransport || !mSctpTransport)
+		throw std::logic_error("Got a DataChannel message without transport");
+
 	shared_ptr<DataChannel> channel;
 	if (auto it = mDataChannels.find(message->stream); it != mDataChannels.end()) {
 		channel = it->second;
 	} else {
-		channel = std::make_shared<DataChannel>(message->stream, mSctpTransport);
-		channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this, channel));
-		mDataChannels.insert(std::make_pair(message->stream, channel));
+		const byte dataChannelOpenMessage{0x03};
+		unsigned int remoteParity = (mIceTransport->role() == Description::Role::Active) ? 1 : 0;
+		if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
+		    message->stream % 2 == remoteParity) {
+			channel = std::make_shared<DataChannel>(message->stream, mSctpTransport);
+			channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this, channel));
+			mDataChannels.insert(std::make_pair(message->stream, channel));
+		} else {
+			// Invalid, close the DataChannel by resetting the stream
+			mSctpTransport->reset(message->stream);
+			return;
+		}
 	}
 
 	channel->incoming(message);