Browse Source

Merge branch 'paullouisageneau:master' into cheungxiongwei

cheungxiongwei 3 years ago
parent
commit
a28d8153fb

+ 1 - 1
CMakeLists.txt

@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.7)
 project(libdatachannel
-	VERSION 0.17.0
+	VERSION 0.17.1
 	LANGUAGES CXX)
 set(PROJECT_DESCRIPTION "C/C++ WebRTC network library featuring Data Channels, Media Transport, and WebSockets")
 

+ 1 - 0
include/rtc/peerconnection.hpp

@@ -85,6 +85,7 @@ public:
 	optional<Description> remoteDescription() const;
 	optional<string> localAddress() const;
 	optional<string> remoteAddress() const;
+	uint16_t maxDataChannelId() const;
 	bool getSelectedCandidatePair(Candidate *local, Candidate *remote);
 
 	void setLocalDescription(Description::Type type = Description::Type::Unspec);

+ 71 - 65
src/impl/dtlssrtptransport.cpp

@@ -146,16 +146,73 @@ bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 	return Transport::outgoing(message); // bypass DTLS DSCP marking
 }
 
-void DtlsSrtpTransport::incoming(message_ptr message) {
-	if (!mInitDone) {
-		// Bypas
-		DtlsTransport::incoming(message);
+void DtlsSrtpTransport::recvMedia(message_ptr message) {
+	// The RTP header has a minimum size of 12 bytes
+	// An RTCP packet can have a minimum size of 8 bytes
+	int size = int(message->size());
+	if (size < 8) {
+		COUNTER_MEDIA_TRUNCATED++;
+		PLOG_VERBOSE << "Incoming SRTP/SRTCP packet too short, size=" << size;
 		return;
 	}
 
-	int size = int(message->size());
-	if (size == 0)
-		return;
+	uint8_t value2 = to_integer<uint8_t>(*(message->begin() + 1)) & 0x7F;
+	PLOG_VERBOSE << "Demultiplexing SRTCP and SRTP with RTP payload type, value="
+	             << unsigned(value2);
+
+	// See RFC 5761 reference above
+	if (value2 >= 64 && value2 <= 95) { // Range 64-95 (inclusive) MUST be RTCP
+		PLOG_VERBOSE << "Incoming SRTCP packet, size=" << size;
+		if (srtp_err_status_t err = srtp_unprotect_rtcp(mSrtpIn, message->data(), &size)) {
+			if (err == srtp_err_status_replay_fail) {
+				PLOG_VERBOSE << "Incoming SRTCP packet is a replay";
+				COUNTER_SRTCP_REPLAY++;
+			} else if (err == srtp_err_status_auth_fail) {
+				PLOG_VERBOSE << "Incoming SRTCP packet failed authentication check";
+				COUNTER_SRTCP_AUTH_FAIL++;
+			} else {
+				PLOG_VERBOSE << "SRTCP unprotect error, status=" << err;
+				COUNTER_SRTCP_FAIL++;
+			}
+
+			return;
+		}
+		PLOG_VERBOSE << "Unprotected SRTCP packet, size=" << size;
+		message->type = Message::Control;
+		message->stream = reinterpret_cast<RtcpSr *>(message->data())->senderSSRC();
+
+	} else {
+		PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
+		if (srtp_err_status_t err = srtp_unprotect(mSrtpIn, message->data(), &size)) {
+			if (err == srtp_err_status_replay_fail) {
+				PLOG_VERBOSE << "Incoming SRTP packet is a replay";
+				COUNTER_SRTP_REPLAY++;
+			} else if (err == srtp_err_status_auth_fail) {
+				PLOG_VERBOSE << "Incoming SRTP packet failed authentication check";
+				COUNTER_SRTP_AUTH_FAIL++;
+			} else {
+				PLOG_VERBOSE << "SRTP unprotect error, status=" << err;
+				COUNTER_SRTP_FAIL++;
+			}
+			return;
+		}
+		PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
+		message->type = Message::Binary;
+		message->stream = reinterpret_cast<RtpHeader *>(message->data())->ssrc();
+	}
+
+	message->resize(size);
+	mSrtpRecvCallback(message);
+}
+
+bool DtlsSrtpTransport::demuxMessage(message_ptr message) {
+	if (!mInitDone) {
+		// Bypass
+		return false;
+	}
+
+	if (message->size() == 0)
+		return false;
 
 	// RFC 5764 5.1.2. Reception
 	// https://www.rfc-editor.org/rfc/rfc5764.html#section-5.1.2
@@ -167,69 +224,18 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 	             << unsigned(value1);
 
 	if (value1 >= 20 && value1 <= 63) {
-		PLOG_VERBOSE << "Incoming DTLS packet, size=" << size;
-		DtlsTransport::incoming(message);
+		PLOG_VERBOSE << "Incoming DTLS packet, size=" << message->size();
+		return false;
 
 	} else if (value1 >= 128 && value1 <= 191) {
-		// The RTP header has a minimum size of 12 bytes
-		// An RTCP packet can have a minimum size of 8 bytes
-		if (size < 8) {
-			COUNTER_MEDIA_TRUNCATED++;
-			PLOG_VERBOSE << "Incoming SRTP/SRTCP packet too short, size=" << size;
-			return;
-		}
-
-		uint8_t value2 = to_integer<uint8_t>(*(message->begin() + 1)) & 0x7F;
-		PLOG_VERBOSE << "Demultiplexing SRTCP and SRTP with RTP payload type, value="
-		             << unsigned(value2);
-
-		// See RFC 5761 reference above
-		if (value2 >= 64 && value2 <= 95) { // Range 64-95 (inclusive) MUST be RTCP
-			PLOG_VERBOSE << "Incoming SRTCP packet, size=" << size;
-			if (srtp_err_status_t err = srtp_unprotect_rtcp(mSrtpIn, message->data(), &size)) {
-				if (err == srtp_err_status_replay_fail) {
-					PLOG_VERBOSE << "Incoming SRTCP packet is a replay";
-					COUNTER_SRTCP_REPLAY++;
-				} else if (err == srtp_err_status_auth_fail) {
-					PLOG_VERBOSE << "Incoming SRTCP packet failed authentication check";
-					COUNTER_SRTCP_AUTH_FAIL++;
-				} else {
-					PLOG_VERBOSE << "SRTCP unprotect error, status=" << err;
-					COUNTER_SRTCP_FAIL++;
-				}
-
-				return;
-			}
-			PLOG_VERBOSE << "Unprotected SRTCP packet, size=" << size;
-			message->type = Message::Control;
-			message->stream = reinterpret_cast<RtcpSr *>(message->data())->senderSSRC();
-
-		} else {
-			PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
-			if (srtp_err_status_t err = srtp_unprotect(mSrtpIn, message->data(), &size)) {
-				if (err == srtp_err_status_replay_fail) {
-					PLOG_VERBOSE << "Incoming SRTP packet is a replay";
-					COUNTER_SRTP_REPLAY++;
-				} else if (err == srtp_err_status_auth_fail) {
-					PLOG_VERBOSE << "Incoming SRTP packet failed authentication check";
-					COUNTER_SRTP_AUTH_FAIL++;
-				} else {
-					PLOG_VERBOSE << "SRTP unprotect error, status=" << err;
-					COUNTER_SRTP_FAIL++;
-				}
-				return;
-			}
-			PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
-			message->type = Message::Binary;
-			message->stream = reinterpret_cast<RtpHeader *>(message->data())->ssrc();
-		}
-
-		message->resize(size);
-		mSrtpRecvCallback(message);
+		recvMedia(std::move(message));
+		return true;
 
 	} else {
 		COUNTER_UNKNOWN_PACKET_TYPE++;
-		PLOG_VERBOSE << "Unknown packet type, value=" << unsigned(value1) << ", size=" << size;
+		PLOG_VERBOSE << "Unknown packet type, value=" << unsigned(value1)
+		             << ", size=" << message->size();
+		return true;
 	}
 }
 

+ 2 - 1
src/impl/dtlssrtptransport.hpp

@@ -47,7 +47,8 @@ public:
 	bool sendMedia(message_ptr message);
 
 private:
-	void incoming(message_ptr message) override;
+	void recvMedia(message_ptr message);
+	bool demuxMessage(message_ptr message) override;
 	void postHandshake() override;
 
 	message_callback mSrtpRecvCallback;

+ 17 - 1
src/impl/dtlstransport.cpp

@@ -157,6 +157,11 @@ bool DtlsTransport::outgoing(message_ptr message) {
 	return Transport::outgoing(std::move(message));
 }
 
+bool DtlsTransport::demuxMessage(message_ptr) {
+	// Dummy
+	return false;
+}
+
 void DtlsTransport::postHandshake() {
 	// Dummy
 }
@@ -298,8 +303,11 @@ ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *dat
 ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
 	try {
-		if (auto next = t->mIncomingQueue.pop()) {
+		while (auto next = t->mIncomingQueue.pop()) {
 			message_ptr message = std::move(*next);
+			if (t->demuxMessage(message))
+				continue;
+
 			ssize_t len = std::min(maxlen, message->size());
 			std::memcpy(data, message->data(), len);
 			gnutls_transport_set_errno(t->mSession, 0);
@@ -504,6 +512,11 @@ bool DtlsTransport::outgoing(message_ptr message) {
 	return Transport::outgoing(std::move(message));
 }
 
+bool DtlsTransport::demuxMessage(message_ptr) {
+	// Dummy
+	return false;
+}
+
 void DtlsTransport::postHandshake() {
 	// Dummy
 }
@@ -526,6 +539,9 @@ void DtlsTransport::runRecvLoop() {
 			// Process pending messages
 			while (auto next = mIncomingQueue.tryPop()) {
 				message_ptr message = std::move(*next);
+				if (demuxMessage(message))
+					continue;
+
 				BIO_write(mInBio, message->data(), int(message->size()));
 
 				if (state() == State::Connecting) {

+ 1 - 0
src/impl/dtlstransport.hpp

@@ -55,6 +55,7 @@ public:
 protected:
 	virtual void incoming(message_ptr message) override;
 	virtual bool outgoing(message_ptr message) override;
+	virtual bool demuxMessage(message_ptr message);
 	virtual void postHandshake();
 	void runRecvLoop();
 

+ 4 - 0
src/impl/internals.hpp

@@ -44,6 +44,10 @@ const size_t MAX_NUMERICSERV_LEN = 6;  // Max port string representation length
 
 const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
 
+const uint16_t MAX_SCTP_STREAMS_COUNT = 1024; // Max number of negotiated SCTP streams
+                                              // RFC 8831 recommends 65535 but usrsctp needs a lot
+                                              // of memory, Chromium historically limits to 1024.
+
 const size_t DEFAULT_LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Default local max message size
 const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP
 

+ 23 - 5
src/impl/peerconnection.cpp

@@ -599,11 +599,12 @@ void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(string label, DataChannelInit init) {
 	cleanupDataChannels();
 	std::unique_lock lock(mDataChannelsMutex); // we are going to emplace
+	const uint16_t maxStream = maxDataChannelStream();
 	uint16_t stream;
 	if (init.id) {
 		stream = *init.id;
-		if (stream == 65535)
-			throw std::invalid_argument("Invalid DataChannel id");
+		if (stream > maxStream)
+			throw std::invalid_argument("DataChannel stream id is too high");
 	} else {
 		// RFC 5763: The answerer MUST use either a setup attribute value of setup:active or
 		// setup:passive. [...] Thus, setup:active is RECOMMENDED.
@@ -618,10 +619,14 @@ shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(string label, DataCha
 		// the DTLS server, it MUST choose an odd one.
 		// See https://www.rfc-editor.org/rfc/rfc8832.html#section-6
 		stream = (role == Description::Role::Active) ? 0 : 1;
-		while (mDataChannels.find(stream) != mDataChannels.end()) {
-			if (stream >= 65535 - 2)
+		while (true) {
+			if (stream > maxStream)
 				throw std::runtime_error("Too many DataChannels");
 
+			auto it = mDataChannels.find(stream);
+			if (it == mDataChannels.end() || !it->second.lock())
+				break;
+
 			stream += 2;
 		}
 	}
@@ -646,6 +651,11 @@ shared_ptr<DataChannel> PeerConnection::findDataChannel(uint16_t stream) {
 	return nullptr;
 }
 
+uint16_t PeerConnection::maxDataChannelStream() const {
+	auto sctpTransport = std::atomic_load(&mSctpTransport);
+	return sctpTransport ? sctpTransport->maxStream() : (MAX_SCTP_STREAMS_COUNT - 1);
+}
+
 void PeerConnection::shiftDataChannels() {
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto sctpTransport = std::atomic_load(&mSctpTransport);
@@ -698,7 +708,15 @@ void PeerConnection::cleanupDataChannels() {
 
 void PeerConnection::openDataChannels() {
 	if (auto transport = std::atomic_load(&mSctpTransport))
-		iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->open(transport); });
+		iterateDataChannels([&](shared_ptr<DataChannel> channel) {
+			// Check again as the maximum might have been negotiated lower
+			if (channel->stream() <= transport->maxStream()) {
+				channel->open(transport);
+			} else {
+				channel->triggerError("DataChannel stream id is too high");
+				channel->remoteClose();
+			}
+		});
 }
 
 void PeerConnection::closeDataChannels() {

+ 1 - 0
src/impl/peerconnection.hpp

@@ -68,6 +68,7 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
 
 	shared_ptr<DataChannel> emplaceDataChannel(string label, DataChannelInit init);
 	shared_ptr<DataChannel> findDataChannel(uint16_t stream);
+	uint16_t maxDataChannelStream() const;
 	void shiftDataChannels();
 	void iterateDataChannels(std::function<void(shared_ptr<DataChannel> channel)> func);
 	void cleanupDataChannels();

+ 19 - 5
src/impl/sctptransport.cpp

@@ -281,9 +281,12 @@ SctpTransport::SctpTransport(shared_ptr<Transport> lower, const Configuration &c
 	// The number of streams negotiated during SCTP association setup SHOULD be 65535, which is the
 	// maximum number of streams that can be negotiated during the association setup.
 	// See https://www.rfc-editor.org/rfc/rfc8831.html#section-6.2
+	// However, usrsctp allocates tables to hold the stream states. For 65535 streams, it results in
+	// the waste of a few MBs for each association. Therefore, we use a lower limit to save memory.
+	// See https://github.com/sctplab/usrsctp/issues/121
 	struct sctp_initmsg sinit = {};
-	sinit.sinit_num_ostreams = 65535;
-	sinit.sinit_max_instreams = 65535;
+	sinit.sinit_num_ostreams = MAX_SCTP_STREAMS_COUNT;
+	sinit.sinit_max_instreams = MAX_SCTP_STREAMS_COUNT;
 	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_INITMSG, &sinit, sizeof(sinit)))
 		throw std::runtime_error("Could not set socket option SCTP_INITMSG, errno=" +
 		                         std::to_string(errno));
@@ -450,6 +453,11 @@ void SctpTransport::closeStream(unsigned int stream) {
 	mProcessor.enqueue(&SctpTransport::flush, shared_from_this());
 }
 
+unsigned int SctpTransport::maxStream() const {
+	unsigned int streamsCount = mNegotiatedStreamsCount.value_or(MAX_SCTP_STREAMS_COUNT);
+	return streamsCount > 0 ? streamsCount - 1 : 0;
+}
+
 void SctpTransport::incoming(message_ptr message) {
 	// There could be a race condition here where we receive the remote INIT before the local one is
 	// sent, which would result in the connection being aborted. Therefore, we need to wait for data
@@ -804,8 +812,14 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 
 	switch (type) {
 	case SCTP_ASSOC_CHANGE: {
-		const struct sctp_assoc_change &assoc_change = notify->sn_assoc_change;
-		if (assoc_change.sac_state == SCTP_COMM_UP) {
+		PLOG_VERBOSE << "SCTP association change event";
+		const struct sctp_assoc_change &sac = notify->sn_assoc_change;
+		if (sac.sac_state == SCTP_COMM_UP) {
+			PLOG_DEBUG << "SCTP negotiated streams: incoming=" << sac.sac_inbound_streams
+			           << ", outgoing=" << sac.sac_outbound_streams;
+			mNegotiatedStreamsCount.emplace(
+			    std::min(sac.sac_inbound_streams, sac.sac_outbound_streams));
+
 			PLOG_INFO << "SCTP connected";
 			changeState(State::Connected);
 		} else {
@@ -822,7 +836,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 	}
 
 	case SCTP_SENDER_DRY_EVENT: {
-		PLOG_VERBOSE << "SCTP dry event";
+		PLOG_VERBOSE << "SCTP sender dry event";
 		// It should not be necessary since the send callback should have been called already,
 		// but to be sure, let's try to send now.
 		flush();

+ 3 - 0
src/impl/sctptransport.hpp

@@ -59,6 +59,8 @@ public:
 	bool flush();
 	void closeStream(unsigned int stream);
 
+	unsigned int maxStream() const;
+
 	void onBufferedAmount(amount_callback callback) {
 		mBufferedAmountCallback = std::move(callback);
 	}
@@ -106,6 +108,7 @@ private:
 
 	const Ports mPorts;
 	struct socket *mSock;
+	std::optional<uint16_t> mNegotiatedStreamsCount;
 
 	Processor mProcessor;
 	std::atomic<int> mPendingRecvCount = 0;

+ 3 - 3
src/peerconnection.cpp

@@ -259,6 +259,8 @@ optional<string> PeerConnection::remoteAddress() const {
 	return iceTransport ? iceTransport->getRemoteAddress() : nullopt;
 }
 
+uint16_t PeerConnection::maxDataChannelId() const { return impl()->maxDataChannelStream(); }
+
 shared_ptr<DataChannel> PeerConnection::createDataChannel(string label, DataChannelInit init) {
 	auto channelImpl = impl()->emplaceDataChannel(std::move(label), std::move(init));
 	auto channel = std::make_shared<DataChannel>(channelImpl);
@@ -319,9 +321,7 @@ void PeerConnection::onSignalingStateChange(std::function<void(SignalingState st
 	impl()->signalingStateChangeCallback = callback;
 }
 
-void PeerConnection::resetCallbacks() {
-	impl()->resetCallbacks();
-}
+void PeerConnection::resetCallbacks() { impl()->resetCallbacks(); }
 
 bool PeerConnection::getSelectedCandidatePair(Candidate *local, Candidate *remote) {
 	auto iceTransport = impl()->getIceTransport();