Rafi Wiener 2 лет назад
Родитель
Сommit
0b98d7ac20
3 измененных файлов с 65 добавлено и 36 удалено
  1. 50 29
      src/impl/dtlssrtptransport.cpp
  2. 9 3
      src/impl/dtlssrtptransport.hpp
  3. 6 4
      src/impl/dtlstransport.cpp

+ 50 - 29
src/impl/dtlssrtptransport.cpp

@@ -218,33 +218,44 @@ bool DtlsSrtpTransport::demuxMessage(message_ptr message) {
 	}
 }
 
+EncryptionParams DtlsSrtpTransport::getEncryptionParams(string_view suite) {
+	if (suite == "SRTP_AES128_CM_SHA1_80")
+		return {SRTP_AES_128_KEY_LEN, SRTP_SALT_LEN, srtp_profile_aes128_cm_sha1_80};
+	if (suite == "SRTP_AES128_CM_SHA1_32")
+		return {SRTP_AES_128_KEY_LEN, SRTP_SALT_LEN, srtp_profile_aes128_cm_sha1_32};
+	if (suite == "SRTP_AEAD_AES_128_GCM")
+		return {SRTP_AES_128_KEY_LEN, SRTP_AEAD_SALT_LEN, srtp_profile_aead_aes_128_gcm};
+	if (suite == "SRTP_AEAD_AES_256_GCM")
+		return {SRTP_AES_256_KEY_LEN, SRTP_AEAD_SALT_LEN, srtp_profile_aead_aes_256_gcm};
+	throw std::logic_error("Unexpected SRTP suite name: " + std::string(suite));
+}
+
 void DtlsSrtpTransport::postHandshake() {
 	if (mInitDone)
 		return;
-
-	static_assert(SRTP_AES_ICM_128_KEY_LEN_WSALT == SRTP_AES_128_KEY_LEN + SRTP_SALT_LEN);
-
-	const size_t materialLen = SRTP_AES_ICM_128_KEY_LEN_WSALT * 2;
-	unsigned char material[materialLen];
 	const unsigned char *clientKey, *clientSalt, *serverKey, *serverSalt;
-
 #if USE_GNUTLS
 	PLOG_INFO << "Deriving SRTP keying material (GnuTLS)";
-
+	unsigned int keySize = SRTP_AES_128_KEY_LEN;
+	unsigned int saltSize = SRTP_SALT_LEN;
+	auto srtpProfile = srtp_profile_aes128_cm_sha1_80;
+	auto keySizeWithSalt = SRTP_AES_ICM_128_KEY_LEN_WSALT;
+	const size_t materialLen = keySizeWithSalt * 2;
+	std::vector<unsigned char> material(materialLen);
 	gnutls_datum_t clientKeyDatum, clientSaltDatum, serverKeyDatum, serverSaltDatum;
-	gnutls::check(gnutls_srtp_get_keys(mSession, material, materialLen, &clientKeyDatum,
+	gnutls::check(gnutls_srtp_get_keys(mSession, material.data(), materialLen, &clientKeyDatum,
 	                                   &clientSaltDatum, &serverKeyDatum, &serverSaltDatum),
 	              "Failed to derive SRTP keys");
 
-	if (clientKeyDatum.size != SRTP_AES_128_KEY_LEN)
+	if (clientKeyDatum.size != keySize)
 		throw std::logic_error("Unexpected SRTP master key length: " +
 		                       to_string(clientKeyDatum.size));
-	if (clientSaltDatum.size != SRTP_SALT_LEN)
+	if (clientSaltDatum.size != saltSize)
 		throw std::logic_error("Unexpected SRTP salt length: " + to_string(clientSaltDatum.size));
-	if (serverKeyDatum.size != SRTP_AES_128_KEY_LEN)
+	if (serverKeyDatum.size != keySize)
 		throw std::logic_error("Unexpected SRTP master key length: " +
 		                       to_string(serverKeyDatum.size));
-	if (serverSaltDatum.size != SRTP_SALT_LEN)
+	if (serverSaltDatum.size != saltSize)
 		throw std::logic_error("Unexpected SRTP salt size: " + to_string(serverSaltDatum.size));
 
 	clientKey = reinterpret_cast<const unsigned char *>(clientKeyDatum.data);
@@ -254,35 +265,45 @@ void DtlsSrtpTransport::postHandshake() {
 	serverSalt = reinterpret_cast<const unsigned char *>(serverSaltDatum.data);
 #else
 	PLOG_INFO << "Deriving SRTP keying material (OpenSSL)";
-
+	auto profile = SSL_get_selected_srtp_profile(mSsl);
+	if (!profile)
+		throw std::runtime_error("Failed to get SRTP profile: " +
+					openssl::error_string(ERR_get_error()));
+	PLOG_DEBUG << "srtp profile used is: " << profile->name;
+	auto [keySize, saltSize, srtpProfile] = getEncryptionParams(profile->name);
+	auto keySizeWithSalt = keySize + saltSize;
+	const size_t materialLen = keySizeWithSalt * 2;
+	std::vector<unsigned char> material(materialLen);
 	// The extractor provides the client write master key, the server write master key, the client
 	// write master salt and the server write master salt in that order.
 	const string label = "EXTRACTOR-dtls_srtp";
 
 	// returns 1 on success, 0 or -1 on failure (OpenSSL API is a complete mess...)
-	if (SSL_export_keying_material(mSsl, material, materialLen, label.c_str(), label.size(),
+	if (SSL_export_keying_material(mSsl, material.data(), materialLen, label.c_str(), label.size(),
 	                               nullptr, 0, 0) <= 0)
 		throw std::runtime_error("Failed to derive SRTP keys: " +
 		                         openssl::error_string(ERR_get_error()));
 
 	// Order is client key, server key, client salt, and server salt
-	clientKey = material;
-	serverKey = clientKey + SRTP_AES_128_KEY_LEN;
-	clientSalt = serverKey + SRTP_AES_128_KEY_LEN;
-	serverSalt = clientSalt + SRTP_SALT_LEN;
+	clientKey = material.data();
+	serverKey = clientKey + keySize;
+	clientSalt = serverKey + keySize;
+	serverSalt = clientSalt + saltSize;
 #endif
+	mClientSessionKey.resize(keySizeWithSalt);
+	mServerSessionKey.resize(keySizeWithSalt);
+	std::memcpy(mClientSessionKey.data(), clientKey, keySize);
+	std::memcpy(mClientSessionKey.data() + keySize, clientSalt, saltSize);
 
-	std::memcpy(mClientSessionKey, clientKey, SRTP_AES_128_KEY_LEN);
-	std::memcpy(mClientSessionKey + SRTP_AES_128_KEY_LEN, clientSalt, SRTP_SALT_LEN);
-
-	std::memcpy(mServerSessionKey, serverKey, SRTP_AES_128_KEY_LEN);
-	std::memcpy(mServerSessionKey + SRTP_AES_128_KEY_LEN, serverSalt, SRTP_SALT_LEN);
+	std::memcpy(mServerSessionKey.data(), serverKey, keySize);
+	std::memcpy(mServerSessionKey.data() + keySize, serverSalt, saltSize);
 
 	srtp_policy_t inbound = {};
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
+	srtp_crypto_policy_set_from_profile_for_rtp(&inbound.rtp, srtpProfile);
+	srtp_crypto_policy_set_from_profile_for_rtcp(&inbound.rtp, srtpProfile);
 	inbound.ssrc.type = ssrc_any_inbound;
-	inbound.key = mIsClient ? mServerSessionKey : mClientSessionKey;
+	inbound.key = mIsClient ? mServerSessionKey.data() : mClientSessionKey.data();
+
 	inbound.window_size = 1024;
 	inbound.allow_repeat_tx = true;
 	inbound.next = nullptr;
@@ -292,10 +313,10 @@ void DtlsSrtpTransport::postHandshake() {
 		                         to_string(static_cast<int>(err)));
 
 	srtp_policy_t outbound = {};
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
+	srtp_crypto_policy_set_from_profile_for_rtp(&outbound.rtp, srtpProfile);
+	srtp_crypto_policy_set_from_profile_for_rtcp(&outbound.rtp, srtpProfile);
 	outbound.ssrc.type = ssrc_any_outbound;
-	outbound.key = mIsClient ? mClientSessionKey : mServerSessionKey;
+	outbound.key = mIsClient ? mClientSessionKey.data() : mServerSessionKey.data();
 	outbound.window_size = 1024;
 	outbound.allow_repeat_tx = true;
 	outbound.next = nullptr;

+ 9 - 3
src/impl/dtlssrtptransport.hpp

@@ -24,6 +24,12 @@
 
 namespace rtc::impl {
 
+struct EncryptionParams {
+	unsigned int keySize;
+	unsigned int saltSize;
+	srtp_profile_t srtpProfile;
+};
+
 class DtlsSrtpTransport final : public DtlsTransport {
 public:
 	static void Init();
@@ -40,14 +46,14 @@ private:
 	void recvMedia(message_ptr message);
 	bool demuxMessage(message_ptr message) override;
 	void postHandshake() override;
-
+	EncryptionParams getEncryptionParams(string_view suite);
 	message_callback mSrtpRecvCallback;
 
 	srtp_t mSrtpIn, mSrtpOut;
 
 	std::atomic<bool> mInitDone = false;
-	unsigned char mClientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
-	unsigned char mServerSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
+	std::vector<unsigned char> mClientSessionKey;
+	std::vector<unsigned char> mServerSessionKey;
 	std::mutex sendMutex;
 };
 

+ 6 - 4
src/impl/dtlstransport.cpp

@@ -437,10 +437,12 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 		// RFC 8827: The DTLS-SRTP protection profile SRTP_AES128_CM_HMAC_SHA1_80 MUST be supported
 		// See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5 Warning:
 		// SSL_set_tlsext_use_srtp() returns 0 on success and 1 on error
-		if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"))
-			throw std::runtime_error("Failed to set SRTP profile: " +
-			                         openssl::error_string(ERR_get_error()));
-
+		// Try to use GCM suite
+		if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AEAD_AES_256_GCM:SRTP_AEAD_AES_128_GCM:SRTP_AES128_CM_SHA1_80")) {
+			if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"))
+				throw std::runtime_error("Failed to set SRTP profile: " +
+							openssl::error_string(ERR_get_error()));
+		}
 	} catch (...) {
 		if (mSsl)
 			SSL_free(mSsl);