Browse Source

Merge pull request #108 from paullouisageneau/fix-srtp-transport

Fix optional SRTP transport
Paul-Louis Ageneau 5 years ago
parent
commit
be79c68540
5 changed files with 159 additions and 66 deletions
  1. 143 62
      src/dtlssrtptransport.cpp
  2. 4 3
      src/dtlssrtptransport.hpp
  3. 10 0
      src/dtlstransport.cpp
  4. 1 0
      src/dtlstransport.hpp
  5. 1 1
      src/peerconnection.cpp

+ 143 - 62
src/dtlssrtptransport.cpp

@@ -43,50 +43,86 @@ DtlsSrtpTransport::DtlsSrtpTransport(std::shared_ptr<IceTransport> lower,
                     std::move(stateChangeCallback)),
       mSrtpRecvCallback(std::move(srtpRecvCallback)) { // distinct from Transport recv callback
 
-	PLOG_DEBUG << "Initializing SRTP transport";
+	PLOG_DEBUG << "Initializing DTLS-SRTP transport";
 
-#if USE_GNUTLS
-	PLOG_DEBUG << "Initializing DTLS-SRTP transport (GnuTLS)";
-	gnutls::check(gnutls_srtp_set_profile(mSession, GNUTLS_SRTP_AES128_CM_HMAC_SHA1_80),
-	              "Failed to set SRTP profile");
-#else
-	PLOG_DEBUG << "Initializing DTLS-SRTP transport (OpenSSL)";
-	openssl::check(SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"),
-	               "Failed to set SRTP profile");
-#endif
+	if (srtp_err_status_t err = srtp_create(&mSrtpIn, nullptr)) {
+		throw std::runtime_error("SRTP create failed, status=" + to_string(static_cast<int>(err)));
+	}
+	if (srtp_err_status_t err = srtp_create(&mSrtpOut, nullptr)) {
+		srtp_dealloc(mSrtpIn);
+		throw std::runtime_error("SRTP create failed, status=" + to_string(static_cast<int>(err)));
+	}
 }
 
 DtlsSrtpTransport::~DtlsSrtpTransport() {
 	stop();
 
-	if (mCreated)
-		srtp_dealloc(mSrtp);
+	srtp_dealloc(mSrtpIn);
+	srtp_dealloc(mSrtpOut);
 }
 
-bool DtlsSrtpTransport::send(message_ptr message) {
+bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 	if (!message)
 		return false;
 
+	if (!mInitDone) {
+		PLOG_WARNING << "SRTP media sent before keys are derived";
+		return false;
+	}
+
 	int size = message->size();
 	PLOG_VERBOSE << "Send size=" << size;
 
-	// srtp_protect() assumes that it can write SRTP_MAX_TRAILER_LEN (for the authentication tag)
-	// into the location in memory immediately following the RTP packet.
+	// The RTP header has a minimum size of 12 bytes
+	if (size < 12)
+		throw std::runtime_error("RTP/RTCP packet too short");
+
+	// srtp_protect() and srtp_protect_rtcp() assume that they can write SRTP_MAX_TRAILER_LEN (for
+	// the authentication tag) into the location in memory immediately following the RTP packet.
 	message->resize(size + SRTP_MAX_TRAILER_LEN);
-	if (srtp_err_status_t err = srtp_protect(mSrtp, message->data(), &size)) {
-		if (err == srtp_err_status_replay_fail)
-			throw std::runtime_error("SRTP packet is a replay");
-		else
-			throw std::runtime_error("SRTP protect error, status=" +
-			                         to_string(static_cast<int>(err)));
+
+	uint8_t value2 = to_integer<uint8_t>(*(message->begin() + 1)) & 0x7F;
+	PLOG_VERBOSE << "Demultiplexing SRTCP and SRTP with RTP payload type, value="
+	             << unsigned(value2);
+
+	// RFC 5761 Multiplexing RTP and RTCP 4. Distinguishable RTP and RTCP Packets
+	// It is RECOMMENDED to follow the guidelines in the RTP/AVP profile for the choice of RTP
+	// payload type values, with the additional restriction that payload type values in the
+	// range 64-95 MUST NOT be used. Specifically, dynamic RTP payload types SHOULD be chosen in
+	// the range 96-127 where possible. Values below 64 MAY be used if that is insufficient
+	// [...]
+	if (value2 >= 64 && value2 <= 95) { // Range 64-95 (inclusive) MUST be RTCP
+		if (srtp_err_status_t err = srtp_protect_rtcp(mSrtpOut, message->data(), &size)) {
+			if (err == srtp_err_status_replay_fail)
+				throw std::runtime_error("SRTCP packet is a replay");
+			else
+				throw std::runtime_error("SRTCP protect error, status=" +
+				                         to_string(static_cast<int>(err)));
+		}
+		PLOG_VERBOSE << "Protected SRTCP packet, size=" << size;
+	} else {
+		if (srtp_err_status_t err = srtp_protect(mSrtpOut, message->data(), &size)) {
+			if (err == srtp_err_status_replay_fail)
+				throw std::runtime_error("SRTP packet is a replay");
+			else
+				throw std::runtime_error("SRTP protect error, status=" +
+				                         to_string(static_cast<int>(err)));
+		}
+		PLOG_VERBOSE << "Protected SRTP packet, size=" << size;
 	}
-	PLOG_VERBOSE << "Protected SRTP packet, size=" << size;
+
 	message->resize(size);
 	outgoing(message);
 	return true;
 }
 
 void DtlsSrtpTransport::incoming(message_ptr message) {
+	if (!mInitDone) {
+		// Bypas
+		DtlsTransport::incoming(message);
+		return;
+	}
+
 	int size = message->size();
 	if (size == 0)
 		return;
@@ -95,49 +131,80 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 	// The process for demultiplexing a packet is as follows. The receiver looks at the first byte
 	// of the packet. [...] If the value is in between 128 and 191 (inclusive), then the packet is
 	// RTP (or RTCP [...]). If the value is between 20 and 63 (inclusive), the packet is DTLS.
-	uint8_t value = to_integer<uint8_t>(*message->begin());
+	uint8_t value1 = to_integer<uint8_t>(*message->begin());
+	PLOG_VERBOSE << "Demultiplexing DTLS and SRTP/SRTCP with first byte, value="
+	             << unsigned(value1);
 
-	if (value >= 128 && value <= 192) {
+	if (value1 >= 20 && value1 <= 63) {
 		PLOG_VERBOSE << "Incoming DTLS packet, size=" << size;
 		DtlsTransport::incoming(message);
-	} else if (value >= 20 && value <= 64) {
-		PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
 
-		if (srtp_err_status_t err = srtp_unprotect(mSrtp, message->data(), &size)) {
-			if (err == srtp_err_status_replay_fail)
-				PLOG_WARNING << "Incoming SRTP packet is a replay";
-			else
-				PLOG_WARNING << "SRTP unprotect error, status=" << err;
+	} else if (value1 >= 128 && value1 <= 191) {
+		// The RTP header has a minimum size of 12 bytes
+		if (size < 12) {
+			PLOG_WARNING << "Incoming SRTP/SRTCP packet too short, size=" << size;
 			return;
 		}
-		PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
+
+		uint8_t value2 = to_integer<uint8_t>(*(message->begin() + 1)) & 0x7F;
+		PLOG_VERBOSE << "Demultiplexing SRTCP and SRTP with RTP payload type, value="
+		             << unsigned(value2);
+
+		// See RFC 5761 reference above
+		if (value2 >= 64 && value2 <= 95) { // Range 64-95 (inclusive) MUST be RTCP
+			PLOG_VERBOSE << "Incoming SRTCP packet, size=" << size;
+			if (srtp_err_status_t err = srtp_unprotect_rtcp(mSrtpIn, message->data(), &size)) {
+				if (err == srtp_err_status_replay_fail)
+					PLOG_WARNING << "Incoming SRTCP packet is a replay";
+				else
+					PLOG_WARNING << "SRTCP unprotect error, status=" << err;
+				return;
+			}
+			PLOG_VERBOSE << "Unprotected SRTCP packet, size=" << size;
+		} else {
+			PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
+			if (srtp_err_status_t err = srtp_unprotect(mSrtpIn, message->data(), &size)) {
+				if (err == srtp_err_status_replay_fail)
+					PLOG_WARNING << "Incoming SRTP packet is a replay";
+				else
+					PLOG_WARNING << "SRTP unprotect error, status=" << err;
+				return;
+			}
+			PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
+		}
+
 		message->resize(size);
 		mSrtpRecvCallback(message);
 
 	} else {
-		PLOG_WARNING << "Unknown packet type, value=" << value << ", size=" << size;
+		PLOG_WARNING << "Unknown packet type, value=" << unsigned(value1) << ", size=" << size;
 	}
 }
 
+void DtlsSrtpTransport::postCreation() {
+#if USE_GNUTLS
+	PLOG_DEBUG << "Setting SRTP profile (GnuTLS)";
+	gnutls::check(gnutls_srtp_set_profile(mSession, GNUTLS_SRTP_AES128_CM_HMAC_SHA1_80),
+	              "Failed to set SRTP profile");
+#else
+	PLOG_DEBUG << "Setting SRTP profile (OpenSSL)";
+	// returns 0 on success, 1 on error
+	if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"), "Failed to set SRTP profile")
+		throw std::runtime_error("Failed to set SRTP profile: " + openssl::error_string(ERR_get_error()));
+#endif
+}
+
 void DtlsSrtpTransport::postHandshake() {
-	if (mCreated)
+	if (mInitDone)
 		return;
 
-	srtp_policy_t inbound = {};
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
-	inbound.ssrc.type = ssrc_any_inbound;
-
-	srtp_policy_t outbound = {};
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
-	outbound.ssrc.type = ssrc_any_outbound;
-
 	const size_t materialLen = SRTP_AES_ICM_128_KEY_LEN_WSALT * 2;
 	unsigned char material[materialLen];
 	const unsigned char *clientKey, *clientSalt, *serverKey, *serverSalt;
 
 #if USE_GNUTLS
+	PLOG_INFO << "Deriving SRTP keying material (GnuTLS)";
+
 	gnutls_datum_t clientKeyDatum, clientSaltDatum, serverKeyDatum, serverSaltDatum;
 	gnutls::check(gnutls_srtp_get_keys(mSession, material, materialLen, &clientKeyDatum,
 	                                   &clientSaltDatum, &serverKeyDatum, &serverSaltDatum),
@@ -160,18 +227,23 @@ void DtlsSrtpTransport::postHandshake() {
 	serverKey = reinterpret_cast<const unsigned char *>(serverKeyDatum.data);
 	serverSalt = reinterpret_cast<const unsigned char *>(serverSaltDatum.data);
 #else
-	// This provides the client write master key, the server write master key, the client write
-	// master salt and the server write master salt in that order.
+	PLOG_INFO << "Deriving SRTP keying material (OpenSSL)";
+
+	// The extractor provides the client write master key, the server write master key, the client
+	// write master salt and the server write master salt in that order.
 	const string label = "EXTRACTOR-dtls_srtp";
-	openssl::check(SSL_export_keying_material(mSsl, material, materialLen, label.c_str(),
-	                                          label.size(), nullptr, 0, 0),
-	               "Failed to derive SRTP keys");
+
+	// returns 1 on success, 0 or -1 on failure (OpenSSL API is a complete mess...)
+	if (SSL_export_keying_material(mSsl, material, materialLen, label.c_str(), label.size(),
+	                               nullptr, 0, 0) <= 0)
+		throw std::runtime_error("Failed to derive SRTP keys: " +
+		                         openssl::error_string(ERR_get_error()));
 
 	clientKey = material;
 	clientSalt = clientKey + SRTP_AES_128_KEY_LEN;
 
 	serverKey = material + SRTP_AES_ICM_128_KEY_LEN_WSALT;
-	serverSalt = serverSalt + SRTP_AES_128_KEY_LEN;
+	serverSalt = serverKey + SRTP_AES_128_KEY_LEN;
 #endif
 
 	unsigned char clientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
@@ -182,22 +254,31 @@ void DtlsSrtpTransport::postHandshake() {
 	std::memcpy(serverSessionKey, serverKey, SRTP_AES_128_KEY_LEN);
 	std::memcpy(serverSessionKey + SRTP_AES_128_KEY_LEN, serverSalt, SRTP_SALT_LEN);
 
-	if (mIsClient) {
-		inbound.key = serverSessionKey;
-		outbound.key = clientSessionKey;
-	} else {
-		inbound.key = clientSessionKey;
-		outbound.key = serverSessionKey;
-	}
+	srtp_policy_t inbound = {};
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
+	inbound.ssrc.type = ssrc_any_inbound;
+	inbound.ssrc.value = 0;
+	inbound.key = mIsClient ? serverSessionKey : clientSessionKey;
+	inbound.next = nullptr;
+
+	if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound))
+		throw std::runtime_error("SRTP add inbound stream failed, status=" +
+		                         to_string(static_cast<int>(err)));
 
-	srtp_policy_t *policies = &inbound;
-	inbound.next = &outbound;
+	srtp_policy_t outbound = {};
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
+	outbound.ssrc.type = ssrc_any_outbound;
+	outbound.ssrc.value = 0;
+	outbound.key = mIsClient ? clientSessionKey : serverSessionKey;
 	outbound.next = nullptr;
 
-	if (srtp_err_status_t err = srtp_create(&mSrtp, policies))
-		throw std::runtime_error("SRTP create failed, status=" + to_string(static_cast<int>(err)));
+	if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound))
+		throw std::runtime_error("SRTP add outbound stream failed, status=" +
+		                         to_string(static_cast<int>(err)));
 
-	mCreated = true;
+	mInitDone = true;
 }
 
 } // namespace rtc

+ 4 - 3
src/dtlssrtptransport.hpp

@@ -38,16 +38,17 @@ public:
 	                  state_callback stateChangeCallback);
 	~DtlsSrtpTransport();
 
-	bool send(message_ptr message) override;
+	bool sendMedia(message_ptr message);
 
 private:
 	void incoming(message_ptr message) override;
+	void postCreation() override;
 	void postHandshake() override;
 
 	message_callback mSrtpRecvCallback;
 
-	srtp_t mSrtp;
-	bool mCreated = false;
+	srtp_t mSrtpIn, mSrtpOut;
+	bool mInitDone = false;
 };
 
 } // namespace rtc

+ 10 - 0
src/dtlstransport.cpp

@@ -85,6 +85,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 		gnutls_transport_set_pull_function(mSession, ReadCallback);
 		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
 
+		postCreation();
+
 		mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
 		registerIncoming();
 
@@ -137,6 +139,10 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 }
 
+void DtlsTransport::postCreation() {
+	// Dummy
+}
+
 void DtlsTransport::postHandshake() {
 	// Dummy
 }
@@ -408,6 +414,10 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 }
 
+void DtlsTransport::postCreation() {
+	// Dummy
+}
+
 void DtlsTransport::postHandshake() {
 	// Dummy
 }

+ 1 - 0
src/dtlstransport.hpp

@@ -52,6 +52,7 @@ public:
 
 protected:
 	virtual void incoming(message_ptr message) override;
+	virtual void postCreation();
 	virtual void postHandshake();
 	void runRecvLoop();
 

+ 1 - 1
src/peerconnection.cpp

@@ -244,7 +244,7 @@ void PeerConnection::outgoingMedia(message_ptr message) {
 	if (!transport)
 		throw std::runtime_error("PeerConnection is not open");
 
-	std::dynamic_pointer_cast<DtlsSrtpTransport>(transport)->send(message);
+	std::dynamic_pointer_cast<DtlsSrtpTransport>(transport)->sendMedia(message);
 #else
 	PLOG_WARNING << "Ignoring sent media (not compiled with SRTP support)";
 #endif