Browse Source

SRTP has to be aware of every RTP stream. This commit makes that work.

Staz M 4 years ago
parent
commit
0a46aa2c6d

+ 1 - 0
include/rtc/description.hpp

@@ -131,6 +131,7 @@ public:
 
         void addSSRC(uint32_t ssrc, std::string name);
         bool hasSSRC(uint32_t ssrc);
+        std::vector<uint32_t> getSSRCs();
 
 		void setBitrate(int bitrate);
 		int getBitrate() const;

+ 10 - 0
src/description.cpp

@@ -706,6 +706,16 @@ void Description::Media::addRTPMap(const Description::Media::RTPMap& map) {
     mRtpMap.emplace(map.pt, map);
 }
 
+std::vector<uint32_t> Description::Media::getSSRCs() {
+    std::vector<uint32_t> vec;
+    for (auto &val : mAttributes) {
+        PLOG_DEBUG << val;
+        if (val.find("ssrc:") == 0) {
+            vec.emplace_back(std::stoul((std::string)val.substr(5, val.find(" "))));
+        }
+    }
+    return vec;
+}
 
 
 Description::Media::RTPMap::RTPMap(string_view mline) {

+ 51 - 25
src/dtlssrtptransport.cpp

@@ -262,41 +262,67 @@ void DtlsSrtpTransport::postHandshake() {
 	serverSalt = clientSalt + SRTP_SALT_LEN;
 #endif
 
-	unsigned char clientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
 	std::memcpy(clientSessionKey, clientKey, SRTP_AES_128_KEY_LEN);
 	std::memcpy(clientSessionKey + SRTP_AES_128_KEY_LEN, clientSalt, SRTP_SALT_LEN);
 
-	unsigned char serverSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
 	std::memcpy(serverSessionKey, serverKey, SRTP_AES_128_KEY_LEN);
 	std::memcpy(serverSessionKey + SRTP_AES_128_KEY_LEN, serverSalt, SRTP_SALT_LEN);
 
-	srtp_policy_t inbound = {};
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
-	inbound.ssrc.type = ssrc_any_inbound;
-	inbound.ssrc.value = 0;
-	inbound.key = mIsClient ? serverSessionKey : clientSessionKey;
-	inbound.next = nullptr;
-
-	if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound))
-		throw std::runtime_error("SRTP add inbound stream failed, status=" +
-		                         to_string(static_cast<int>(err)));
-
-	srtp_policy_t outbound = {};
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
-	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
-	outbound.ssrc.type = ssrc_any_outbound;
-	outbound.ssrc.value = 0;
-	outbound.key = mIsClient ? clientSessionKey : serverSessionKey;
-	outbound.next = nullptr;
-
-	if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound))
-		throw std::runtime_error("SRTP add outbound stream failed, status=" +
-		                         to_string(static_cast<int>(err)));
+//	srtp_policy_t inbound = {};
+//	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
+//	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
+//	inbound.ssrc.type = ssrc_any_inbound;
+//	inbound.ssrc.value = 0;
+//	inbound.key = mIsClient ? serverSessionKey : clientSessionKey;
+//	inbound.next = nullptr;
+//
+//	if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound))
+//		throw std::runtime_error("SRTP add inbound stream failed, status=" +
+//		                         to_string(static_cast<int>(err)));
+//
+//	srtp_policy_t outbound = {};
+//	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
+//	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
+//	outbound.ssrc.type = ssrc_any_outbound;
+//	outbound.ssrc.value = 0;
+//	outbound.key = mIsClient ? clientSessionKey : serverSessionKey;
+//	outbound.next = nullptr;
+//
+//	if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound))
+//		throw std::runtime_error("SRTP add outbound stream failed, status=" +
+//		                         to_string(static_cast<int>(err)));
 
 	mInitDone = true;
 }
 
+void DtlsSrtpTransport::addSSRC(uint32_t ssrc) {
+    srtp_policy_t inbound = {};
+    srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
+    srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
+    inbound.ssrc.type = ssrc_specific;
+    inbound.ssrc.value = ssrc;
+    inbound.key = mIsClient ? serverSessionKey : clientSessionKey;
+    inbound.next = nullptr;
+
+    if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound))
+        throw std::runtime_error("SRTP add inbound stream failed, status=" +
+                                 to_string(static_cast<int>(err)));
+
+
+    srtp_policy_t outbound = {};
+    srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
+    srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
+    outbound.ssrc.type = ssrc_specific;
+    outbound.ssrc.value = ssrc;
+    outbound.key = mIsClient ? clientSessionKey : serverSessionKey;
+    outbound.next = nullptr;
+
+    if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound))
+        throw std::runtime_error("SRTP add outbound stream failed, status=" +
+                                 to_string(static_cast<int>(err)));
+}
+
+
 } // namespace rtc
 
 #endif

+ 4 - 0
src/dtlssrtptransport.hpp

@@ -39,6 +39,7 @@ public:
 	~DtlsSrtpTransport();
 
 	bool sendMedia(message_ptr message);
+    void addSSRC(uint32_t ssrc);
 
 private:
 	void incoming(message_ptr message) override;
@@ -48,6 +49,9 @@ private:
 
 	srtp_t mSrtpIn, mSrtpOut;
 	bool mInitDone = false;
+
+    unsigned char clientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
+    unsigned char serverSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
 };
 
 } // namespace rtc

+ 1 - 1
src/dtlstransport.cpp

@@ -177,8 +177,8 @@ void DtlsTransport::runRecvLoop() {
 	// Receive loop
 	try {
 		PLOG_INFO << "DTLS handshake finished";
+        postHandshake();
 		changeState(State::Connected);
-		postHandshake();
 
 		const size_t bufferSize = maxMtu;
 		char buffer[bufferSize];

+ 17 - 5
src/peerconnection.cpp

@@ -162,8 +162,8 @@ void PeerConnection::setRemoteDescription(Description description) {
 
 	for (const auto &candidate : remoteCandidates)
 		addRemoteCandidate(candidate);
-	if (std::atomic_load(&mIceTransport)) {
-            openTracks();
+	if (auto transport = std::atomic_load(&mDtlsTransport); transport && transport->state() == rtc::DtlsTransport::State::Connected) {
+        openTracks();
     }
 }
 
@@ -694,11 +694,23 @@ void PeerConnection::openTracks() {
 	if (auto transport = std::atomic_load(&mDtlsTransport)) {
 		auto srtpTransport = std::reinterpret_pointer_cast<DtlsSrtpTransport>(transport);
 		std::shared_lock lock(mTracksMutex); // read-only
-		for (auto it = mTracks.begin(); it != mTracks.end(); ++it)
-			if (auto track = it->second.lock()) {
-			    if (!track->isOpen())
+//		for (auto it = mTracks.begin(); it != mTracks.end(); ++it)
+        for (unsigned int i = 0; i < mTrackLines.size(); i++) {
+            if (auto track = mTrackLines[i].lock()) {
+                if (!track->isOpen()) {
+//                    if (track->description().direction() == rtc::Description::Direction::RecvOnly || track->description().direction() == rtc::Description::Direction::SendRecv)
+//                        srtpTransport->addInboundSSRC(0);
+//                    if (track->description().direction() == rtc::Description::Direction::SendOnly || track->description().direction() == rtc::Description::Direction::SendRecv)
+
+                    for (auto ssrc : track->description().getSSRCs())
+                        srtpTransport->addSSRC(ssrc);
+                    for (auto ssrc : std::get<rtc::Description::Media *>(remoteDescription()->media(i))->getSSRCs())
+                        srtpTransport->addSSRC(ssrc);
+
                     track->open(srtpTransport);
+                }
             }
+        }
 	}
 #endif
 }