Переглянути джерело

Clean up DtlsSrtpTransport

Paul-Louis Ageneau 2 роки тому
батько
коміт
f622d7d2e2
2 змінених файлів з 66 додано та 56 видалено
  1. 58 50
      src/impl/dtlssrtptransport.cpp
  2. 8 6
      src/impl/dtlssrtptransport.hpp

+ 58 - 50
src/impl/dtlssrtptransport.cpp

@@ -218,28 +218,18 @@ bool DtlsSrtpTransport::demuxMessage(message_ptr message) {
 	}
 }
 
-EncryptionParams DtlsSrtpTransport::getEncryptionParams(string_view suite) {
-	if (suite == "SRTP_AES128_CM_SHA1_80")
-		return {SRTP_AES_128_KEY_LEN, SRTP_SALT_LEN, srtp_profile_aes128_cm_sha1_80};
-	if (suite == "SRTP_AES128_CM_SHA1_32")
-		return {SRTP_AES_128_KEY_LEN, SRTP_SALT_LEN, srtp_profile_aes128_cm_sha1_32};
-	if (suite == "SRTP_AEAD_AES_128_GCM")
-		return {SRTP_AES_128_KEY_LEN, SRTP_AEAD_SALT_LEN, srtp_profile_aead_aes_128_gcm};
-	if (suite == "SRTP_AEAD_AES_256_GCM")
-		return {SRTP_AES_256_KEY_LEN, SRTP_AEAD_SALT_LEN, srtp_profile_aead_aes_256_gcm};
-	throw std::logic_error("Unexpected SRTP suite name: " + std::string(suite));
-}
-
 void DtlsSrtpTransport::postHandshake() {
 	if (mInitDone)
 		return;
-	const unsigned char *clientKey, *clientSalt, *serverKey, *serverSalt;
+
 #if USE_GNUTLS
 	PLOG_INFO << "Deriving SRTP keying material (GnuTLS)";
-	unsigned int keySize = SRTP_AES_128_KEY_LEN;
-	unsigned int saltSize = SRTP_SALT_LEN;
-	auto srtpProfile = srtp_profile_aes128_cm_sha1_80;
-	auto keySizeWithSalt = SRTP_AES_ICM_128_KEY_LEN_WSALT;
+
+	const srtp_profile_t srtpProfile = srtp_profile_aes128_cm_sha1_80;
+	const size_t keySize = SRTP_AES_128_KEY_LEN;
+	const size_t saltSize = SRTP_SALT_LEN;
+	const size_t keySizeWithSalt = SRTP_AES_ICM_128_KEY_LEN_WSALT;
+
 	const size_t materialLen = keySizeWithSalt * 2;
 	std::vector<unsigned char> material(materialLen);
 	gnutls_datum_t clientKeyDatum, clientSaltDatum, serverKeyDatum, serverSaltDatum;
@@ -258,59 +248,61 @@ void DtlsSrtpTransport::postHandshake() {
 	if (serverSaltDatum.size != saltSize)
 		throw std::logic_error("Unexpected SRTP salt size: " + to_string(serverSaltDatum.size));
 
-	clientKey = reinterpret_cast<const unsigned char *>(clientKeyDatum.data);
-	clientSalt = reinterpret_cast<const unsigned char *>(clientSaltDatum.data);
+	const unsigned char *clientKey = reinterpret_cast<const unsigned char *>(clientKeyDatum.data);
+	const unsigned char *clientSalt = reinterpret_cast<const unsigned char *>(clientSaltDatum.data);
+	const unsigned char *serverKey = reinterpret_cast<const unsigned char *>(serverKeyDatum.data);
+	const unsigned char *serverSalt = reinterpret_cast<const unsigned char *>(serverSaltDatum.data);
 
-	serverKey = reinterpret_cast<const unsigned char *>(serverKeyDatum.data);
-	serverSalt = reinterpret_cast<const unsigned char *>(serverSaltDatum.data);
 #elif USE_MBEDTLS
 	PLOG_INFO << "Deriving SRTP keying material (Mbed TLS)";
-	unsigned int keySize = SRTP_AES_128_KEY_LEN;
-	unsigned int saltSize = SRTP_SALT_LEN;
-	auto srtpProfile = srtp_profile_aes128_cm_sha1_80;
-	auto keySizeWithSalt = SRTP_AES_ICM_128_KEY_LEN_WSALT;
-	mbedtls_dtls_srtp_info srtpInfo;
 
+	mbedtls_dtls_srtp_info srtpInfo;
 	mbedtls_ssl_get_dtls_srtp_negotiation_result(&mSsl, &srtpInfo);
-	if (srtpInfo.private_chosen_dtls_srtp_profile != MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80) {
+	if (srtpInfo.private_chosen_dtls_srtp_profile != MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80)
 		throw std::runtime_error("Failed to get SRTP profile");
-	}
 
-	const size_t materialLen = keySizeWithSalt * 2;
-	std::vector<unsigned char> material(materialLen);
+	const srtp_profile_t srtpProfile = srtp_profile_aes128_cm_sha1_80;
+	const size_t keySize = SRTP_AES_128_KEY_LEN;
+	const size_t saltSize = SRTP_SALT_LEN;
+	const size_t keySizeWithSalt = SRTP_AES_ICM_128_KEY_LEN_WSALT;
+
+	if (mTlsProfile == MBEDTLS_SSL_TLS_PRF_NONE)
+		throw std::logic_error("TLS PRF type is not set");
+
 	// The extractor provides the client write master key, the server write master key, the client
 	// write master salt and the server write master salt in that order.
 	const string label = "EXTRACTOR-dtls_srtp";
-
-	if (mTlsProfile == MBEDTLS_SSL_TLS_PRF_NONE) {
-		throw std::logic_error("Failed to get SRTP profile");
-	}
+	const size_t materialLen = keySizeWithSalt * 2;
+	std::vector<unsigned char> material(materialLen);
 
 	if (mbedtls_ssl_tls_prf(mTlsProfile, reinterpret_cast<const unsigned char *>(mMasterSecret), 48,
 	                        label.c_str(), reinterpret_cast<const unsigned char *>(mRandBytes), 64,
-	                        material.data(), materialLen) != 0) {
+	                        material.data(), materialLen) != 0)
 		throw std::runtime_error("Failed to derive SRTP keys");
-	}
 
 	// Order is client key, server key, client salt, and server salt
-	clientKey = material.data();
-	serverKey = clientKey + keySize;
-	clientSalt = serverKey + keySize;
-	serverSalt = clientSalt + saltSize;
-#else
+	const unsigned char *clientKey = material.data();
+	const unsigned char *serverKey = clientKey + keySize;
+	const unsigned char *clientSalt = serverKey + keySize;
+	const unsigned char *serverSalt = clientSalt + saltSize;
+
+#else // OpenSSL
 	PLOG_INFO << "Deriving SRTP keying material (OpenSSL)";
 	auto profile = SSL_get_selected_srtp_profile(mSsl);
 	if (!profile)
 		throw std::runtime_error("Failed to get SRTP profile: " +
 		                         openssl::error_string(ERR_get_error()));
-	PLOG_DEBUG << "srtp profile used is: " << profile->name;
-	auto [keySize, saltSize, srtpProfile] = getEncryptionParams(profile->name);
-	auto keySizeWithSalt = keySize + saltSize;
-	const size_t materialLen = keySizeWithSalt * 2;
-	std::vector<unsigned char> material(materialLen);
+
+	PLOG_DEBUG << "SRTP profile is: " << profile->name;
+
+	const auto [srtpProfile, keySize, saltSize] = getProfileParamsFromName(profile->name);
+	const size_t keySizeWithSalt = keySize + saltSize;
+
 	// The extractor provides the client write master key, the server write master key, the client
 	// write master salt and the server write master salt in that order.
 	const string label = "EXTRACTOR-dtls_srtp";
+	const size_t materialLen = keySizeWithSalt * 2;
+	std::vector<unsigned char> material(materialLen);
 
 	// returns 1 on success, 0 or -1 on failure (OpenSSL API is a complete mess...)
 	if (SSL_export_keying_material(mSsl, material.data(), materialLen, label.c_str(), label.size(),
@@ -319,11 +311,12 @@ void DtlsSrtpTransport::postHandshake() {
 		                         openssl::error_string(ERR_get_error()));
 
 	// Order is client key, server key, client salt, and server salt
-	clientKey = material.data();
-	serverKey = clientKey + keySize;
-	clientSalt = serverKey + keySize;
-	serverSalt = clientSalt + saltSize;
+	const unsigned char *clientKey = material.data();
+	const unsigned char *serverKey = clientKey + keySize;
+	const unsigned char *clientSalt = serverKey + keySize;
+	const unsigned char *serverSalt = clientSalt + saltSize;
 #endif
+
 	mClientSessionKey.resize(keySizeWithSalt);
 	mServerSessionKey.resize(keySizeWithSalt);
 	std::memcpy(mClientSessionKey.data(), clientKey, keySize);
@@ -362,6 +355,21 @@ void DtlsSrtpTransport::postHandshake() {
 	mInitDone = true;
 }
 
+#if !USE_GNUTLS && !USE_MBEDTLS
+ProfileParams DtlsSrtpTransport::getProfileParamsFromName(string_view name) {
+	if (name == "SRTP_AES128_CM_SHA1_80")
+		return {srtp_profile_aes128_cm_sha1_80, SRTP_AES_128_KEY_LEN, SRTP_SALT_LEN};
+	if (name == "SRTP_AES128_CM_SHA1_32")
+		return {srtp_profile_aes128_cm_sha1_32, SRTP_AES_128_KEY_LEN, SRTP_SALT_LEN};
+	if (name == "SRTP_AEAD_AES_128_GCM")
+		return {srtp_profile_aead_aes_128_gcm, SRTP_AES_128_KEY_LEN, SRTP_AEAD_SALT_LEN};
+	if (name == "SRTP_AEAD_AES_256_GCM")
+		return {srtp_profile_aead_aes_256_gcm, SRTP_AES_256_KEY_LEN, SRTP_AEAD_SALT_LEN};
+
+	throw std::logic_error("Unknown SRTP profile name: " + std::string(name));
+}
+#endif
+
 } // namespace rtc::impl
 
 #endif

+ 8 - 6
src/impl/dtlssrtptransport.hpp

@@ -24,10 +24,10 @@
 
 namespace rtc::impl {
 
-struct EncryptionParams {
-	unsigned int keySize;
-	unsigned int saltSize;
+struct ProfileParams {
 	srtp_profile_t srtpProfile;
+	size_t keySize;
+	size_t saltSize;
 };
 
 class DtlsSrtpTransport final : public DtlsTransport {
@@ -46,11 +46,13 @@ private:
 	void recvMedia(message_ptr message);
 	bool demuxMessage(message_ptr message) override;
 	void postHandshake() override;
-	EncryptionParams getEncryptionParams(string_view suite);
-	message_callback mSrtpRecvCallback;
 
-	srtp_t mSrtpIn, mSrtpOut;
+#if !USE_GNUTLS && !USE_MBEDTLS
+	ProfileParams getProfileParamsFromName(string_view name);
+#endif
 
+	message_callback mSrtpRecvCallback;
+	srtp_t mSrtpIn, mSrtpOut;
 	std::atomic<bool> mInitDone = false;
 	std::vector<unsigned char> mClientSessionKey;
 	std::vector<unsigned char> mServerSessionKey;