Pārlūkot izejas kodu

Changed DataChannel stream id assignement to happen at connection

Paul-Louis Ageneau 3 gadi atpakaļ
vecāks
revīzija
e53dd3223e

+ 2 - 2
include/rtc/datachannel.hpp

@@ -39,8 +39,8 @@ public:
 	DataChannel(impl_ptr<impl::DataChannel> impl);
 	virtual ~DataChannel();
 
-	uint16_t stream() const;
-	uint16_t id() const;
+	optional<uint16_t> stream() const;
+	optional<uint16_t> id() const;
 	string label() const;
 	string protocol() const;
 	Reliability reliability() const;

+ 4 - 1
src/capi.cpp

@@ -863,7 +863,10 @@ int rtcDeleteDataChannel(int dc) {
 int rtcGetDataChannelStream(int dc) {
 	return wrap([dc] {
 		auto dataChannel = getDataChannel(dc);
-		return int(dataChannel->id());
+		if (auto stream = dataChannel->stream())
+			return int(*stream);
+		else
+			return RTC_ERR_NOT_AVAIL;
 	});
 }
 

+ 2 - 2
src/datachannel.cpp

@@ -46,9 +46,9 @@ DataChannel::~DataChannel() {
 
 void DataChannel::close() { return impl()->close(); }
 
-uint16_t DataChannel::stream() const { return impl()->stream(); }
+optional<uint16_t> DataChannel::stream() const { return impl()->stream(); }
 
-uint16_t DataChannel::id() const { return impl()->stream(); }
+optional<uint16_t> DataChannel::id() const { return impl()->stream(); }
 
 string DataChannel::label() const { return impl()->label(); }
 

+ 33 - 25
src/impl/datachannel.cpp

@@ -83,10 +83,9 @@ bool DataChannel::IsOpenMessage(message_ptr message) {
 	return !message->empty() && raw[0] == MESSAGE_OPEN;
 }
 
-DataChannel::DataChannel(weak_ptr<PeerConnection> pc, uint16_t stream, string label,
-                         string protocol, Reliability reliability)
-    : mPeerConnection(pc), mStream(stream), mLabel(std::move(label)),
-      mProtocol(std::move(protocol)),
+DataChannel::DataChannel(weak_ptr<PeerConnection> pc, string label, string protocol,
+                         Reliability reliability)
+    : mPeerConnection(pc), mLabel(std::move(label)), mProtocol(std::move(protocol)),
       mReliability(std::make_shared<Reliability>(std::move(reliability))),
       mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
 
@@ -105,8 +104,8 @@ void DataChannel::close() {
 		transport = mSctpTransport.lock();
 	}
 
-	if (mIsOpen.exchange(false) && transport)
-		transport->closeStream(mStream);
+	if (mIsOpen.exchange(false) && transport && mStream.has_value())
+		transport->closeStream(mStream.value());
 
 	if (!mIsClosed.exchange(true))
 		triggerClosed();
@@ -152,7 +151,7 @@ optional<message_variant> DataChannel::peek() {
 
 size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
 
-uint16_t DataChannel::stream() const {
+optional<uint16_t> DataChannel::stream() const {
 	std::shared_lock lock(mMutex);
 	return mStream;
 }
@@ -181,8 +180,13 @@ size_t DataChannel::maxMessageSize() const {
 	return pc ? pc->remoteMaxMessageSize() : DEFAULT_MAX_MESSAGE_SIZE;
 }
 
-void DataChannel::shiftStream() {
-	// Ignore
+void DataChannel::assignStream(uint16_t stream) {
+	std::unique_lock lock(mMutex);
+
+	if (mStream.has_value())
+		throw std::logic_error("DataChannel already has a stream assigned");
+
+	mStream = stream;
 }
 
 void DataChannel::open(shared_ptr<SctpTransport> transport) {
@@ -208,12 +212,15 @@ bool DataChannel::outgoing(message_ptr message) {
 		if (!transport || mIsClosed)
 			throw std::runtime_error("DataChannel is closed");
 
+		if (!mStream.has_value())
+			throw std::logic_error("DataChannel has no stream assigned");
+
 		if (message->size() > maxMessageSize())
-			throw std::runtime_error("Message size exceeds limit");
+			throw std::invalid_argument("Message size exceeds limit");
 
 		// Before the ACK has been received on a DataChannel, all messages must be sent ordered
 		message->reliability = mIsOpen ? mReliability : nullptr;
-		message->stream = mStream;
+		message->stream = mStream.value();
 	}
 
 	return transport->send(message);
@@ -259,22 +266,19 @@ void DataChannel::incoming(message_ptr message) {
 	}
 }
 
-OutgoingDataChannel::OutgoingDataChannel(weak_ptr<PeerConnection> pc, uint16_t stream,
-                                             string label, string protocol, Reliability reliability)
-    : DataChannel(pc, stream, std::move(label), std::move(protocol), std::move(reliability)) {}
+OutgoingDataChannel::OutgoingDataChannel(weak_ptr<PeerConnection> pc, string label, string protocol,
+                                         Reliability reliability)
+    : DataChannel(pc, std::move(label), std::move(protocol), std::move(reliability)) {}
 
 OutgoingDataChannel::~OutgoingDataChannel() {}
 
-void OutgoingDataChannel::shiftStream() {
-	std::shared_lock lock(mMutex);
-	if (mStream % 2 == 1)
-		mStream -= 1;
-}
-
 void OutgoingDataChannel::open(shared_ptr<SctpTransport> transport) {
 	std::unique_lock lock(mMutex);
 	mSctpTransport = transport;
 
+	if (!mStream.has_value())
+		throw std::runtime_error("DataChannel has no stream assigned");
+
 	uint8_t channelType;
 	uint32_t reliabilityParameter;
 	switch (mReliability->type) {
@@ -313,7 +317,7 @@ void OutgoingDataChannel::open(shared_ptr<SctpTransport> transport) {
 
 	lock.unlock();
 
-	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
+	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream.value()));
 }
 
 void OutgoingDataChannel::processOpenMessage(message_ptr) {
@@ -321,8 +325,9 @@ void OutgoingDataChannel::processOpenMessage(message_ptr) {
 }
 
 IncomingDataChannel::IncomingDataChannel(weak_ptr<PeerConnection> pc,
-                                         weak_ptr<SctpTransport> transport, uint16_t stream)
-    : DataChannel(pc, stream, "", "", {}) {
+                                         weak_ptr<SctpTransport> transport)
+    : DataChannel(pc, "", "", {}) {
+
 	mSctpTransport = transport;
 }
 
@@ -336,7 +341,10 @@ void IncomingDataChannel::processOpenMessage(message_ptr message) {
 	std::unique_lock lock(mMutex);
 	auto transport = mSctpTransport.lock();
 	if (!transport)
-		throw std::runtime_error("DataChannel has no transport");
+		throw std::logic_error("DataChannel has no transport");
+
+	if (!mStream.has_value())
+		throw std::logic_error("DataChannel has no stream assigned");
 
 	if (message->size() < sizeof(OpenMessage))
 		throw std::invalid_argument("DataChannel open message too small");
@@ -375,7 +383,7 @@ void IncomingDataChannel::processOpenMessage(message_ptr message) {
 	auto &ack = *reinterpret_cast<AckMessage *>(buffer.data());
 	ack.type = MESSAGE_ACK;
 
-	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
+	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream.value()));
 
 	if (!mIsOpen.exchange(true))
 		triggerOpen();

+ 10 - 11
src/impl/datachannel.hpp

@@ -37,7 +37,7 @@ struct PeerConnection;
 struct DataChannel : Channel, std::enable_shared_from_this<DataChannel> {
 	static bool IsOpenMessage(message_ptr message);
 
-	DataChannel(weak_ptr<PeerConnection> pc, uint16_t stream, string label, string protocol,
+	DataChannel(weak_ptr<PeerConnection> pc, string label, string protocol,
 	            Reliability reliability);
 	virtual ~DataChannel();
 
@@ -50,7 +50,7 @@ struct DataChannel : Channel, std::enable_shared_from_this<DataChannel> {
 	optional<message_variant> peek() override;
 	size_t availableAmount() const override;
 
-	uint16_t stream() const;
+	optional<uint16_t> stream() const;
 	string label() const;
 	string protocol() const;
 	Reliability reliability() const;
@@ -59,7 +59,7 @@ struct DataChannel : Channel, std::enable_shared_from_this<DataChannel> {
 	bool isClosed(void) const;
 	size_t maxMessageSize() const;
 
-	virtual void shiftStream();
+	virtual void assignStream(uint16_t stream);
 	virtual void open(shared_ptr<SctpTransport> transport);
 	virtual void processOpenMessage(message_ptr);
 
@@ -67,32 +67,31 @@ protected:
 	const weak_ptr<impl::PeerConnection> mPeerConnection;
 	weak_ptr<SctpTransport> mSctpTransport;
 
-	uint16_t mStream;
+	optional<uint16_t> mStream;
 	string mLabel;
 	string mProtocol;
 	shared_ptr<Reliability> mReliability;
 
 	mutable std::shared_mutex mMutex;
 
-	Queue<message_ptr> mRecvQueue;
-
 	std::atomic<bool> mIsOpen = false;
 	std::atomic<bool> mIsClosed = false;
+
+private:
+	Queue<message_ptr> mRecvQueue;
 };
 
 struct OutgoingDataChannel final : public DataChannel {
-	OutgoingDataChannel(weak_ptr<PeerConnection> pc, uint16_t stream, string label,
-	                      string protocol, Reliability reliability);
+	OutgoingDataChannel(weak_ptr<PeerConnection> pc, string label, string protocol,
+	                    Reliability reliability);
 	~OutgoingDataChannel();
 
-	void shiftStream() override;
 	void open(shared_ptr<SctpTransport> transport) override;
 	void processOpenMessage(message_ptr message) override;
 };
 
 struct IncomingDataChannel final : public DataChannel {
-	IncomingDataChannel(weak_ptr<PeerConnection> pc, weak_ptr<SctpTransport> transport,
-	                    uint16_t stream);
+	IncomingDataChannel(weak_ptr<PeerConnection> pc, weak_ptr<SctpTransport> transport);
 	~IncomingDataChannel();
 
 	void open(shared_ptr<SctpTransport> transport) override;

+ 74 - 64
src/impl/peerconnection.cpp

@@ -289,9 +289,6 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 		ports.local = local->application()->sctpPort().value_or(DEFAULT_SCTP_PORT);
 		ports.remote = remote->application()->sctpPort().value_or(DEFAULT_SCTP_PORT);
 
-		// This is the last occasion to ensure the stream numbers are coherent with the role
-		shiftDataChannels();
-
 		auto transport = std::make_shared<SctpTransport>(
 		    lower, config, std::move(ports), weak_bind(&PeerConnection::forwardMessage, this, _1),
 		    weak_bind(&PeerConnection::forwardBufferedAmount, this, _1, _2),
@@ -302,6 +299,7 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 			    switch (transportState) {
 			    case SctpTransport::State::Connected:
 				    changeState(State::Connected);
+				    assignDataChannels();
 				    mProcessor.enqueue(&PeerConnection::openDataChannels, shared_from_this());
 				    break;
 			    case SctpTransport::State::Failed:
@@ -446,7 +444,8 @@ void PeerConnection::forwardMessage(message_ptr message) {
 			channel->close();
 		}
 
-		channel = std::make_shared<IncomingDataChannel>(weak_from_this(), sctpTransport, stream);
+		channel = std::make_shared<IncomingDataChannel>(weak_from_this(), sctpTransport);
+		channel->assignStream(stream);
 		channel->openCallback =
 		    weak_bind(&PeerConnection::triggerDataChannel, this, weak_ptr<DataChannel>{channel});
 
@@ -600,46 +599,38 @@ void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(string label, DataChannelInit init) {
 	cleanupDataChannels();
 	std::unique_lock lock(mDataChannelsMutex); // we are going to emplace
-	const uint16_t maxStream = maxDataChannelStream();
-	uint16_t stream;
-	if (init.id) {
-		stream = *init.id;
-		if (stream > maxStream)
-			throw std::invalid_argument("DataChannel stream id is too high");
-	} else {
-		// RFC 5763: The answerer MUST use either a setup attribute value of setup:active or
-		// setup:passive. [...] Thus, setup:active is RECOMMENDED.
-		// See https://www.rfc-editor.org/rfc/rfc5763.html#section-5
-		// Therefore, we assume passive role if we are the offerer.
-		auto iceTransport = getIceTransport();
-		auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
-
-		// RFC 8832: The peer that initiates opening a data channel selects a stream identifier for
-		// which the corresponding incoming and outgoing streams are unused.  If the side is acting
-		// as the DTLS client, it MUST choose an even stream identifier; if the side is acting as
-		// the DTLS server, it MUST choose an odd one.
-		// See https://www.rfc-editor.org/rfc/rfc8832.html#section-6
-		stream = (role == Description::Role::Active) ? 0 : 1;
-		while (true) {
-			if (stream > maxStream)
-				throw std::runtime_error("Too many DataChannels");
 
-			auto it = mDataChannels.find(stream);
-			if (it == mDataChannels.end() || !it->second.lock())
-				break;
-
-			stream += 2;
-		}
-	}
 	// If the DataChannel is user-negotiated, do not negotiate it in-band
 	auto channel =
 	    init.negotiated
-	        ? std::make_shared<DataChannel>(weak_from_this(), stream, std::move(label),
+	        ? std::make_shared<DataChannel>(weak_from_this(), std::move(label),
 	                                        std::move(init.protocol), std::move(init.reliability))
-	        : std::make_shared<OutgoingDataChannel>(weak_from_this(), stream, std::move(label),
+	        : std::make_shared<OutgoingDataChannel>(weak_from_this(), std::move(label),
 	                                                std::move(init.protocol),
 	                                                std::move(init.reliability));
-	mDataChannels.emplace(std::make_pair(stream, channel));
+
+	// If the user supplied a stream id, use it, otherwise assign it later
+	if (init.id) {
+		uint16_t stream = *init.id;
+		if (stream > maxDataChannelStream())
+			throw std::invalid_argument("DataChannel stream id is too high");
+
+		channel->assignStream(stream);
+		mDataChannels.emplace(std::make_pair(stream, channel));
+
+	} else {
+		mUnassignedDataChannels.push_back(channel);
+	}
+
+	lock.unlock(); // we are going to call assignDataChannels()
+
+	// If SCTP is connected, assign and open now
+	auto sctpTransport = getSctpTransport();
+	if (sctpTransport && sctpTransport->state() == SctpTransport::State::Connected) {
+		assignDataChannels();
+		channel->open(sctpTransport);
+	}
+
 	return channel;
 }
 
@@ -657,21 +648,43 @@ uint16_t PeerConnection::maxDataChannelStream() const {
 	return sctpTransport ? sctpTransport->maxStream() : (MAX_SCTP_STREAMS_COUNT - 1);
 }
 
-void PeerConnection::shiftDataChannels() {
-	auto iceTransport = std::atomic_load(&mIceTransport);
-	auto sctpTransport = std::atomic_load(&mSctpTransport);
-	if (!sctpTransport && iceTransport && iceTransport->role() == Description::Role::Active) {
-		std::unique_lock lock(mDataChannelsMutex); // we are going to swap the container
-		decltype(mDataChannels) newDataChannels;
-		auto it = mDataChannels.begin();
-		while (it != mDataChannels.end()) {
-			auto channel = it->second.lock();
-			channel->shiftStream();
-			newDataChannels.emplace(channel->stream(), channel);
-			++it;
+void PeerConnection::assignDataChannels() {
+	std::unique_lock lock(mDataChannelsMutex); // we are going to emplace
+
+	auto iceTransport = getIceTransport();
+	if (!iceTransport)
+		throw std::logic_error("Attempted to assign DataChannels without ICE transport");
+
+	const uint16_t maxStream = maxDataChannelStream();
+	for (auto it = mUnassignedDataChannels.begin(); it != mUnassignedDataChannels.end(); ++it) {
+		auto channel = it->lock();
+		if (!channel)
+			continue;
+
+		// RFC 8832: The peer that initiates opening a data channel selects a stream identifier
+		// for which the corresponding incoming and outgoing streams are unused.  If the side is
+		// acting as the DTLS client, it MUST choose an even stream identifier; if the side is
+		// acting as the DTLS server, it MUST choose an odd one. See
+		// https://www.rfc-editor.org/rfc/rfc8832.html#section-6
+		uint16_t stream = (iceTransport->role() == Description::Role::Active) ? 0 : 1;
+		while (true) {
+			if (stream > maxStream)
+				throw std::runtime_error("Too many DataChannels");
+
+			auto it = mDataChannels.find(stream);
+			if (it == mDataChannels.end() || !it->second.lock())
+				break;
+
+			stream += 2;
 		}
-		std::swap(mDataChannels, newDataChannels);
+
+		PLOG_DEBUG << "Assigning stream " << stream  << " to DataChannel";
+
+		channel->assignStream(stream);
+		mDataChannels.emplace(std::make_pair(stream, channel));
 	}
+
+	mUnassignedDataChannels.clear();
 }
 
 void PeerConnection::iterateDataChannels(
@@ -690,8 +703,13 @@ void PeerConnection::iterateDataChannels(
 		}
 	}
 
-	for (auto &channel : locked)
-		func(std::move(channel));
+	for (auto &channel : locked) {
+		try {
+			func(std::move(channel));
+		} catch (const std::exception &e) {
+			PLOG_WARNING << e.what();
+		}
+	}
 }
 
 void PeerConnection::cleanupDataChannels() {
@@ -710,13 +728,8 @@ void PeerConnection::cleanupDataChannels() {
 void PeerConnection::openDataChannels() {
 	if (auto transport = std::atomic_load(&mSctpTransport))
 		iterateDataChannels([&](shared_ptr<DataChannel> channel) {
-			// Check again as the maximum might have been negotiated lower
-			if (channel->stream() <= transport->maxStream()) {
+			if (!channel->isOpen())
 				channel->open(transport);
-			} else {
-				channel->triggerError("DataChannel stream id is too high");
-				channel->remoteClose();
-			}
 		});
 }
 
@@ -841,7 +854,7 @@ void PeerConnection::processLocalDescription(Description description) {
 			    rtc::overloaded{
 			        [&](Description::Application *remoteApp) {
 				        std::shared_lock lock(mDataChannelsMutex);
-				        if (!mDataChannels.empty()) {
+				        if (!mDataChannels.empty() || !mUnassignedDataChannels.empty()) {
 					        // Prefer local description
 					        Description::Application app(remoteApp->mid());
 					        app.setSctpPort(localSctpPort);
@@ -931,11 +944,12 @@ void PeerConnection::processLocalDescription(Description description) {
 		// Add application for data channels
 		if (!description.hasApplication()) {
 			std::shared_lock lock(mDataChannelsMutex);
-			if (!mDataChannels.empty()) {
+			if (!mDataChannels.empty() || !mUnassignedDataChannels.empty()) {
 				// Prevents mid collision with remote or local tracks
 				unsigned int m = 0;
 				while (description.hasMid(std::to_string(m)))
 					++m;
+
 				Description::Application app(std::to_string(m));
 				app.setSctpPort(localSctpPort);
 				app.setMaxMessageSize(localMaxMessageSize);
@@ -1023,10 +1037,6 @@ void PeerConnection::processRemoteDescription(Description description) {
 
 	iceTransport->setRemoteDescription(std::move(description));
 
-	// Since we assumed passive role during DataChannel creation, we might need to shift the stream
-	// numbers from odd to even.
-	shiftDataChannels();
-
 	if (description.hasApplication()) {
 		auto dtlsTransport = std::atomic_load(&mDtlsTransport);
 		auto sctpTransport = std::atomic_load(&mSctpTransport);

+ 5 - 2
src/impl/peerconnection.hpp

@@ -70,7 +70,7 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
 	shared_ptr<DataChannel> emplaceDataChannel(string label, DataChannelInit init);
 	shared_ptr<DataChannel> findDataChannel(uint16_t stream);
 	uint16_t maxDataChannelStream() const;
-	void shiftDataChannels();
+	void assignDataChannels();
 	void iterateDataChannels(std::function<void(shared_ptr<DataChannel> channel)> func);
 	void cleanupDataChannels();
 	void openDataChannels();
@@ -138,9 +138,12 @@ private:
 	shared_ptr<SctpTransport> mSctpTransport;
 
 	std::unordered_map<uint16_t, weak_ptr<DataChannel>> mDataChannels; // by stream ID
+	std::vector<weak_ptr<DataChannel>> mUnassignedDataChannels;
+	std::shared_mutex mDataChannelsMutex;
+
 	std::unordered_map<string, weak_ptr<Track>> mTracks;               // by mid
 	std::vector<weak_ptr<Track>> mTrackLines;                          // by SDP order
-	std::shared_mutex mDataChannelsMutex, mTracksMutex;
+	std::shared_mutex mTracksMutex;
 
 	Queue<shared_ptr<DataChannel>> mPendingDataChannels;
 	Queue<shared_ptr<Track>> mPendingTracks;

+ 0 - 4
src/peerconnection.cpp

@@ -265,10 +265,6 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(string label, DataChan
 	auto channelImpl = impl()->emplaceDataChannel(std::move(label), std::move(init));
 	auto channel = std::make_shared<DataChannel>(channelImpl);
 
-	if (auto transport = impl()->getSctpTransport())
-		if (transport->state() == impl::SctpTransport::State::Connected)
-			channelImpl->open(transport);
-
 	// Renegotiation is needed iff the current local description does not have application
 	auto local = impl()->localDescription();
 	if (!local || !local->hasApplication())