Browse Source

Changed to two SRTP sessions and introduced srtp_add_stream()

Paul-Louis Ageneau 5 years ago
parent
commit
5afbe10d01
2 changed files with 49 additions and 44 deletions
  1. 47 42
      src/dtlssrtptransport.cpp
  2. 2 2
      src/dtlssrtptransport.hpp

+ 47 - 42
src/dtlssrtptransport.cpp

@@ -42,25 +42,30 @@ DtlsSrtpTransport::DtlsSrtpTransport(std::shared_ptr<IceTransport> lower,
     : DtlsTransport(lower, certificate, std::move(verifierCallback),
                     std::move(stateChangeCallback)),
       mSrtpRecvCallback(std::move(srtpRecvCallback)) { // distinct from Transport recv callback
-#if USE_GNUTLS
-	PLOG_DEBUG << "Initializing DTLS-SRTP transport (GnuTLS)";
-#else
-	PLOG_DEBUG << "Initializing DTLS-SRTP transport (OpenSSL)";
-#endif
+
+	PLOG_DEBUG << "Initializing DTLS-SRTP transport";
+
+	if (srtp_err_status_t err = srtp_create(&mSrtpIn, nullptr)) {
+		throw std::runtime_error("SRTP create failed, status=" + to_string(static_cast<int>(err)));
+	}
+	if (srtp_err_status_t err = srtp_create(&mSrtpOut, nullptr)) {
+		srtp_dealloc(mSrtpIn);
+		throw std::runtime_error("SRTP create failed, status=" + to_string(static_cast<int>(err)));
+	}
 }
 
 DtlsSrtpTransport::~DtlsSrtpTransport() {
 	stop();
 
-	if (mCreated)
-		srtp_dealloc(mSrtp);
+	srtp_dealloc(mSrtpIn);
+	srtp_dealloc(mSrtpOut);
 }
 
 bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 	if (!message)
 		return false;
 
-	if (!mCreated) {
+	if (!mInitDone) {
 		PLOG_WARNING << "SRTP media sent before keys are derived";
 		return false;
 	}
@@ -71,7 +76,7 @@ bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 	// srtp_protect() assumes that it can write SRTP_MAX_TRAILER_LEN (for the authentication tag)
 	// into the location in memory immediately following the RTP packet.
 	message->resize(size + SRTP_MAX_TRAILER_LEN);
-	if (srtp_err_status_t err = srtp_protect(mSrtp, message->data(), &size)) {
+	if (srtp_err_status_t err = srtp_protect(mSrtpOut, message->data(), &size)) {
 		if (err == srtp_err_status_replay_fail)
 			throw std::runtime_error("SRTP packet is a replay");
 		else
@@ -85,7 +90,7 @@ bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 }
 
 void DtlsSrtpTransport::incoming(message_ptr message) {
-	if (!mCreated) {
+	if (!mInitDone) {
 		// Bypas
 		DtlsTransport::incoming(message);
 		return;
@@ -107,7 +112,7 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 	} else if (value >= 20 && value <= 63) {
 		PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
 
-		if (srtp_err_status_t err = srtp_unprotect(mSrtp, message->data(), &size)) {
+		if (srtp_err_status_t err = srtp_unprotect(mSrtpIn, message->data(), &size)) {
 			if (err == srtp_err_status_replay_fail)
 				PLOG_WARNING << "Incoming SRTP packet is a replay";
 			else
@@ -119,16 +124,17 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 		mSrtpRecvCallback(message);
 
 	} else {
-		PLOG_WARNING << "Unknown packet type, value=" << value << ", size=" << size;
+		PLOG_WARNING << "Unknown packet type, value=" << unsigned(value) << ", size=" << size;
 	}
 }
 
 void DtlsSrtpTransport::postCreation() {
-	PLOG_DEBUG << "Setting SRTP profile";
 #if USE_GNUTLS
+	PLOG_DEBUG << "Setting SRTP profile (GnuTLS)";
 	gnutls::check(gnutls_srtp_set_profile(mSession, GNUTLS_SRTP_AES128_CM_HMAC_SHA1_80),
 	              "Failed to set SRTP profile");
 #else
+	PLOG_DEBUG << "Setting SRTP profile (OpenSSL)";
 	// returns 0 on success, 1 on error
 	if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"), "Failed to set SRTP profile")
 		throw std::runtime_error("Failed to set SRTP profile: " + openssl::error_string(ERR_get_error()));
@@ -136,28 +142,16 @@ void DtlsSrtpTransport::postCreation() {
 }
 
 void DtlsSrtpTransport::postHandshake() {
-	if (mCreated)
+	if (mInitDone)
 		return;
 
-	PLOG_INFO << "Deriving SRTP keying material";
-
-	srtp_policy_t inbound = {};
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
-	inbound.ssrc.type = ssrc_any_inbound;
-	inbound.ssrc.value = 0;
-
-	srtp_policy_t outbound = {};
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
-	outbound.ssrc.type = ssrc_any_outbound;
-	outbound.ssrc.value = 0;
-
 	const size_t materialLen = SRTP_AES_ICM_128_KEY_LEN_WSALT * 2;
 	unsigned char material[materialLen];
 	const unsigned char *clientKey, *clientSalt, *serverKey, *serverSalt;
 
 #if USE_GNUTLS
+	PLOG_INFO << "Deriving SRTP keying material (GnuTLS)";
+
 	gnutls_datum_t clientKeyDatum, clientSaltDatum, serverKeyDatum, serverSaltDatum;
 	gnutls::check(gnutls_srtp_get_keys(mSession, material, materialLen, &clientKeyDatum,
 	                                   &clientSaltDatum, &serverKeyDatum, &serverSaltDatum),
@@ -180,8 +174,10 @@ void DtlsSrtpTransport::postHandshake() {
 	serverKey = reinterpret_cast<const unsigned char *>(serverKeyDatum.data);
 	serverSalt = reinterpret_cast<const unsigned char *>(serverSaltDatum.data);
 #else
-	// This provides the client write master key, the server write master key, the client write
-	// master salt and the server write master salt in that order.
+	PLOG_INFO << "Deriving SRTP keying material (OpenSSL)";
+
+	// The extractor provides the client write master key, the server write master key, the client
+	// write master salt and the server write master salt in that order.
 	const string label = "EXTRACTOR-dtls_srtp";
 
 	// returns 1 on success, 0 or -1 on failure (OpenSSL API is a complete mess...)
@@ -205,22 +201,31 @@ void DtlsSrtpTransport::postHandshake() {
 	std::memcpy(serverSessionKey, serverKey, SRTP_AES_128_KEY_LEN);
 	std::memcpy(serverSessionKey + SRTP_AES_128_KEY_LEN, serverSalt, SRTP_SALT_LEN);
 
-	if (mIsClient) {
-		inbound.key = serverSessionKey;
-		outbound.key = clientSessionKey;
-	} else {
-		inbound.key = clientSessionKey;
-		outbound.key = serverSessionKey;
-	}
+	srtp_policy_t inbound = {};
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
+	inbound.ssrc.type = ssrc_any_inbound;
+	inbound.ssrc.value = 0;
+	inbound.key = mIsClient ? serverSessionKey : clientSessionKey;
+	inbound.next = nullptr;
+
+	if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound))
+		throw std::runtime_error("SRTP add inbound stream failed, status=" +
+		                         to_string(static_cast<int>(err)));
 
-	srtp_policy_t *policies = &inbound;
-	inbound.next = &outbound;
+	srtp_policy_t outbound = {};
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
+	outbound.ssrc.type = ssrc_any_outbound;
+	outbound.ssrc.value = 0;
+	outbound.key = mIsClient ? clientSessionKey : serverSessionKey;
 	outbound.next = nullptr;
 
-	if (srtp_err_status_t err = srtp_create(&mSrtp, policies))
-		throw std::runtime_error("SRTP create failed, status=" + to_string(static_cast<int>(err)));
+	if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound))
+		throw std::runtime_error("SRTP add outbound stream failed, status=" +
+		                         to_string(static_cast<int>(err)));
 
-	mCreated = true;
+	mInitDone = true;
 }
 
 } // namespace rtc

+ 2 - 2
src/dtlssrtptransport.hpp

@@ -47,8 +47,8 @@ private:
 
 	message_callback mSrtpRecvCallback;
 
-	srtp_t mSrtp;
-	bool mCreated = false;
+	srtp_t mSrtpIn, mSrtpOut;
+	bool mInitDone = false;
 };
 
 } // namespace rtc