Browse Source

Merge pull request #604 from paullouisageneau/srtp-in-thread

Perform SRTP processing in DTLS thread
Paul-Louis Ageneau 3 years ago
parent
commit
b67065a77a

+ 71 - 65
src/impl/dtlssrtptransport.cpp

@@ -146,16 +146,73 @@ bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 	return Transport::outgoing(message); // bypass DTLS DSCP marking
 	return Transport::outgoing(message); // bypass DTLS DSCP marking
 }
 }
 
 
-void DtlsSrtpTransport::incoming(message_ptr message) {
-	if (!mInitDone) {
-		// Bypas
-		DtlsTransport::incoming(message);
+void DtlsSrtpTransport::recvMedia(message_ptr message) {
+	// The RTP header has a minimum size of 12 bytes
+	// An RTCP packet can have a minimum size of 8 bytes
+	int size = int(message->size());
+	if (size < 8) {
+		COUNTER_MEDIA_TRUNCATED++;
+		PLOG_VERBOSE << "Incoming SRTP/SRTCP packet too short, size=" << size;
 		return;
 		return;
 	}
 	}
 
 
-	int size = int(message->size());
-	if (size == 0)
-		return;
+	uint8_t value2 = to_integer<uint8_t>(*(message->begin() + 1)) & 0x7F;
+	PLOG_VERBOSE << "Demultiplexing SRTCP and SRTP with RTP payload type, value="
+	             << unsigned(value2);
+
+	// See RFC 5761 reference above
+	if (value2 >= 64 && value2 <= 95) { // Range 64-95 (inclusive) MUST be RTCP
+		PLOG_VERBOSE << "Incoming SRTCP packet, size=" << size;
+		if (srtp_err_status_t err = srtp_unprotect_rtcp(mSrtpIn, message->data(), &size)) {
+			if (err == srtp_err_status_replay_fail) {
+				PLOG_VERBOSE << "Incoming SRTCP packet is a replay";
+				COUNTER_SRTCP_REPLAY++;
+			} else if (err == srtp_err_status_auth_fail) {
+				PLOG_VERBOSE << "Incoming SRTCP packet failed authentication check";
+				COUNTER_SRTCP_AUTH_FAIL++;
+			} else {
+				PLOG_VERBOSE << "SRTCP unprotect error, status=" << err;
+				COUNTER_SRTCP_FAIL++;
+			}
+
+			return;
+		}
+		PLOG_VERBOSE << "Unprotected SRTCP packet, size=" << size;
+		message->type = Message::Control;
+		message->stream = reinterpret_cast<RtcpSr *>(message->data())->senderSSRC();
+
+	} else {
+		PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
+		if (srtp_err_status_t err = srtp_unprotect(mSrtpIn, message->data(), &size)) {
+			if (err == srtp_err_status_replay_fail) {
+				PLOG_VERBOSE << "Incoming SRTP packet is a replay";
+				COUNTER_SRTP_REPLAY++;
+			} else if (err == srtp_err_status_auth_fail) {
+				PLOG_VERBOSE << "Incoming SRTP packet failed authentication check";
+				COUNTER_SRTP_AUTH_FAIL++;
+			} else {
+				PLOG_VERBOSE << "SRTP unprotect error, status=" << err;
+				COUNTER_SRTP_FAIL++;
+			}
+			return;
+		}
+		PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
+		message->type = Message::Binary;
+		message->stream = reinterpret_cast<RtpHeader *>(message->data())->ssrc();
+	}
+
+	message->resize(size);
+	mSrtpRecvCallback(message);
+}
+
+bool DtlsSrtpTransport::demuxMessage(message_ptr message) {
+	if (!mInitDone) {
+		// Bypass
+		return false;
+	}
+
+	if (message->size() == 0)
+		return false;
 
 
 	// RFC 5764 5.1.2. Reception
 	// RFC 5764 5.1.2. Reception
 	// https://www.rfc-editor.org/rfc/rfc5764.html#section-5.1.2
 	// https://www.rfc-editor.org/rfc/rfc5764.html#section-5.1.2
@@ -167,69 +224,18 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 	             << unsigned(value1);
 	             << unsigned(value1);
 
 
 	if (value1 >= 20 && value1 <= 63) {
 	if (value1 >= 20 && value1 <= 63) {
-		PLOG_VERBOSE << "Incoming DTLS packet, size=" << size;
-		DtlsTransport::incoming(message);
+		PLOG_VERBOSE << "Incoming DTLS packet, size=" << message->size();
+		return false;
 
 
 	} else if (value1 >= 128 && value1 <= 191) {
 	} else if (value1 >= 128 && value1 <= 191) {
-		// The RTP header has a minimum size of 12 bytes
-		// An RTCP packet can have a minimum size of 8 bytes
-		if (size < 8) {
-			COUNTER_MEDIA_TRUNCATED++;
-			PLOG_VERBOSE << "Incoming SRTP/SRTCP packet too short, size=" << size;
-			return;
-		}
-
-		uint8_t value2 = to_integer<uint8_t>(*(message->begin() + 1)) & 0x7F;
-		PLOG_VERBOSE << "Demultiplexing SRTCP and SRTP with RTP payload type, value="
-		             << unsigned(value2);
-
-		// See RFC 5761 reference above
-		if (value2 >= 64 && value2 <= 95) { // Range 64-95 (inclusive) MUST be RTCP
-			PLOG_VERBOSE << "Incoming SRTCP packet, size=" << size;
-			if (srtp_err_status_t err = srtp_unprotect_rtcp(mSrtpIn, message->data(), &size)) {
-				if (err == srtp_err_status_replay_fail) {
-					PLOG_VERBOSE << "Incoming SRTCP packet is a replay";
-					COUNTER_SRTCP_REPLAY++;
-				} else if (err == srtp_err_status_auth_fail) {
-					PLOG_VERBOSE << "Incoming SRTCP packet failed authentication check";
-					COUNTER_SRTCP_AUTH_FAIL++;
-				} else {
-					PLOG_VERBOSE << "SRTCP unprotect error, status=" << err;
-					COUNTER_SRTCP_FAIL++;
-				}
-
-				return;
-			}
-			PLOG_VERBOSE << "Unprotected SRTCP packet, size=" << size;
-			message->type = Message::Control;
-			message->stream = reinterpret_cast<RtcpSr *>(message->data())->senderSSRC();
-
-		} else {
-			PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
-			if (srtp_err_status_t err = srtp_unprotect(mSrtpIn, message->data(), &size)) {
-				if (err == srtp_err_status_replay_fail) {
-					PLOG_VERBOSE << "Incoming SRTP packet is a replay";
-					COUNTER_SRTP_REPLAY++;
-				} else if (err == srtp_err_status_auth_fail) {
-					PLOG_VERBOSE << "Incoming SRTP packet failed authentication check";
-					COUNTER_SRTP_AUTH_FAIL++;
-				} else {
-					PLOG_VERBOSE << "SRTP unprotect error, status=" << err;
-					COUNTER_SRTP_FAIL++;
-				}
-				return;
-			}
-			PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
-			message->type = Message::Binary;
-			message->stream = reinterpret_cast<RtpHeader *>(message->data())->ssrc();
-		}
-
-		message->resize(size);
-		mSrtpRecvCallback(message);
+		recvMedia(std::move(message));
+		return true;
 
 
 	} else {
 	} else {
 		COUNTER_UNKNOWN_PACKET_TYPE++;
 		COUNTER_UNKNOWN_PACKET_TYPE++;
-		PLOG_VERBOSE << "Unknown packet type, value=" << unsigned(value1) << ", size=" << size;
+		PLOG_VERBOSE << "Unknown packet type, value=" << unsigned(value1)
+		             << ", size=" << message->size();
+		return true;
 	}
 	}
 }
 }
 
 

+ 2 - 1
src/impl/dtlssrtptransport.hpp

@@ -47,7 +47,8 @@ public:
 	bool sendMedia(message_ptr message);
 	bool sendMedia(message_ptr message);
 
 
 private:
 private:
-	void incoming(message_ptr message) override;
+	void recvMedia(message_ptr message);
+	bool demuxMessage(message_ptr message) override;
 	void postHandshake() override;
 	void postHandshake() override;
 
 
 	message_callback mSrtpRecvCallback;
 	message_callback mSrtpRecvCallback;

+ 17 - 1
src/impl/dtlstransport.cpp

@@ -157,6 +157,11 @@ bool DtlsTransport::outgoing(message_ptr message) {
 	return Transport::outgoing(std::move(message));
 	return Transport::outgoing(std::move(message));
 }
 }
 
 
+bool DtlsTransport::demuxMessage(message_ptr) {
+	// Dummy
+	return false;
+}
+
 void DtlsTransport::postHandshake() {
 void DtlsTransport::postHandshake() {
 	// Dummy
 	// Dummy
 }
 }
@@ -298,8 +303,11 @@ ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *dat
 ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
 	try {
 	try {
-		if (auto next = t->mIncomingQueue.pop()) {
+		while (auto next = t->mIncomingQueue.pop()) {
 			message_ptr message = std::move(*next);
 			message_ptr message = std::move(*next);
+			if (t->demuxMessage(message))
+				continue;
+
 			ssize_t len = std::min(maxlen, message->size());
 			ssize_t len = std::min(maxlen, message->size());
 			std::memcpy(data, message->data(), len);
 			std::memcpy(data, message->data(), len);
 			gnutls_transport_set_errno(t->mSession, 0);
 			gnutls_transport_set_errno(t->mSession, 0);
@@ -504,6 +512,11 @@ bool DtlsTransport::outgoing(message_ptr message) {
 	return Transport::outgoing(std::move(message));
 	return Transport::outgoing(std::move(message));
 }
 }
 
 
+bool DtlsTransport::demuxMessage(message_ptr) {
+	// Dummy
+	return false;
+}
+
 void DtlsTransport::postHandshake() {
 void DtlsTransport::postHandshake() {
 	// Dummy
 	// Dummy
 }
 }
@@ -526,6 +539,9 @@ void DtlsTransport::runRecvLoop() {
 			// Process pending messages
 			// Process pending messages
 			while (auto next = mIncomingQueue.tryPop()) {
 			while (auto next = mIncomingQueue.tryPop()) {
 				message_ptr message = std::move(*next);
 				message_ptr message = std::move(*next);
+				if (demuxMessage(message))
+					continue;
+
 				BIO_write(mInBio, message->data(), int(message->size()));
 				BIO_write(mInBio, message->data(), int(message->size()));
 
 
 				if (state() == State::Connecting) {
 				if (state() == State::Connecting) {

+ 1 - 0
src/impl/dtlstransport.hpp

@@ -55,6 +55,7 @@ public:
 protected:
 protected:
 	virtual void incoming(message_ptr message) override;
 	virtual void incoming(message_ptr message) override;
 	virtual bool outgoing(message_ptr message) override;
 	virtual bool outgoing(message_ptr message) override;
+	virtual bool demuxMessage(message_ptr message);
 	virtual void postHandshake();
 	virtual void postHandshake();
 	void runRecvLoop();
 	void runRecvLoop();