Browse Source

Merge pull request #668 from SE2Dev/optimize

Remove Extra Lookup in PeerConnection::forwardMedia()
Paul-Louis Ageneau 3 years ago
parent
commit
e42d92c113
5 changed files with 76 additions and 74 deletions
  1. 3 3
      include/rtc/description.hpp
  2. 3 3
      src/description.cpp
  3. 48 61
      src/impl/peerconnection.cpp
  4. 5 5
      src/impl/peerconnection.hpp
  5. 17 2
      test/track.cpp

+ 3 - 3
include/rtc/description.hpp

@@ -191,9 +191,9 @@ public:
 		void removeSSRC(uint32_t ssrc);
 		void removeSSRC(uint32_t ssrc);
 		void replaceSSRC(uint32_t old, uint32_t ssrc, optional<string> name,
 		void replaceSSRC(uint32_t old, uint32_t ssrc, optional<string> name,
 		                 optional<string> msid = nullopt, optional<string> trackID = nullopt);
 		                 optional<string> msid = nullopt, optional<string> trackID = nullopt);
-		bool hasSSRC(uint32_t ssrc);
-		std::vector<uint32_t> getSSRCs();
-		std::optional<std::string> getCNameForSsrc(uint32_t ssrc);
+		bool hasSSRC(uint32_t ssrc) const;
+		std::vector<uint32_t> getSSRCs() const;
+		std::optional<std::string> getCNameForSsrc(uint32_t ssrc) const;
 
 
 		int bitrate() const;
 		int bitrate() const;
 		void setBitrate(int bitrate);
 		void setBitrate(int bitrate);

+ 3 - 3
src/description.cpp

@@ -776,7 +776,7 @@ void Description::Media::replaceSSRC(uint32_t old, uint32_t ssrc, optional<strin
 	addSSRC(ssrc, std::move(name), std::move(msid), std::move(trackID));
 	addSSRC(ssrc, std::move(name), std::move(msid), std::move(trackID));
 }
 }
 
 
-bool Description::Media::hasSSRC(uint32_t ssrc) {
+bool Description::Media::hasSSRC(uint32_t ssrc) const {
 	return std::find(mSsrcs.begin(), mSsrcs.end(), ssrc) != mSsrcs.end();
 	return std::find(mSsrcs.begin(), mSsrcs.end(), ssrc) != mSsrcs.end();
 }
 }
 
 
@@ -905,9 +905,9 @@ Description::Media Description::Media::reciprocate() const {
 	return reciprocated;
 	return reciprocated;
 }
 }
 
 
-std::vector<uint32_t> Description::Media::getSSRCs() { return mSsrcs; }
+std::vector<uint32_t> Description::Media::getSSRCs() const { return mSsrcs; }
 
 
-optional<string> Description::Media::getCNameForSsrc(uint32_t ssrc) {
+optional<string> Description::Media::getCNameForSsrc(uint32_t ssrc) const {
 	auto it = mCNameMap.find(ssrc);
 	auto it = mCNameMap.find(ssrc);
 	if (it != mCNameMap.end()) {
 	if (it != mCNameMap.end()) {
 		return it->second;
 		return it->second;

+ 48 - 61
src/impl/peerconnection.cpp

@@ -517,12 +517,11 @@ void PeerConnection::forwardMedia(message_ptr message) {
 		}
 		}
 
 
 		if (!ssrcs.empty()) {
 		if (!ssrcs.empty()) {
+			std::shared_lock lock(mTracksMutex); // read-only
 			for (uint32_t ssrc : ssrcs) {
 			for (uint32_t ssrc : ssrcs) {
-				if (auto mid = getMidFromSsrc(ssrc)) {
-					std::shared_lock lock(mTracksMutex); // read-only
-					if (auto it = mTracks.find(*mid); it != mTracks.end())
-						if (auto track = it->second.lock())
-							track->incoming(message);
+				if (auto it = mTracksBySsrc.find(ssrc); it != mTracksBySsrc.end()) {
+					if (auto track = it->second.lock())
+						track->incoming(message);
 				}
 				}
 			}
 			}
 			return;
 			return;
@@ -530,11 +529,11 @@ void PeerConnection::forwardMedia(message_ptr message) {
 	}
 	}
 
 
 	uint32_t ssrc = uint32_t(message->stream);
 	uint32_t ssrc = uint32_t(message->stream);
-	if (auto mid = getMidFromSsrc(ssrc)) {
-		std::shared_lock lock(mTracksMutex); // read-only
-		if (auto it = mTracks.find(*mid); it != mTracks.end())
-			if (auto track = it->second.lock())
-				track->incoming(message);
+
+	std::shared_lock lock(mTracksMutex); // read-only
+	if (auto it = mTracksBySsrc.find(ssrc); it != mTracksBySsrc.end()) {
+		if (auto track = it->second.lock())
+			track->incoming(message);
 	} else {
 	} else {
 		/*
 		/*
 		 * TODO: So the problem is that when stop sending streams, we stop getting report blocks for
 		 * TODO: So the problem is that when stop sending streams, we stop getting report blocks for
@@ -547,57 +546,6 @@ void PeerConnection::forwardMedia(message_ptr message) {
 	}
 	}
 }
 }
 
 
-optional<std::string> PeerConnection::getMidFromSsrc(uint32_t ssrc) {
-	if (auto it = mMidFromSsrc.find(ssrc); it != mMidFromSsrc.end())
-		return it->second;
-
-	{
-		std::lock_guard lock(mRemoteDescriptionMutex);
-		if (!mRemoteDescription)
-			return nullopt;
-
-		for (unsigned int i = 0; i < mRemoteDescription->mediaCount(); ++i) {
-			if (auto found =
-			        std::visit(rtc::overloaded{[&](Description::Application *) -> optional<string> {
-				                                   return std::nullopt;
-			                                   },
-			                                   [&](Description::Media *media) -> optional<string> {
-				                                   return media->hasSSRC(ssrc)
-				                                              ? std::make_optional(media->mid())
-				                                              : nullopt;
-			                                   }},
-			                   mRemoteDescription->media(i))) {
-
-				mMidFromSsrc.emplace(ssrc, *found);
-				return *found;
-			}
-		}
-	}
-	{
-		std::lock_guard lock(mLocalDescriptionMutex);
-		if (!mLocalDescription)
-			return nullopt;
-		for (unsigned int i = 0; i < mLocalDescription->mediaCount(); ++i) {
-			if (auto found =
-			        std::visit(rtc::overloaded{[&](Description::Application *) -> optional<string> {
-				                                   return std::nullopt;
-			                                   },
-			                                   [&](Description::Media *media) -> optional<string> {
-				                                   return media->hasSSRC(ssrc)
-				                                              ? std::make_optional(media->mid())
-				                                              : nullopt;
-			                                   }},
-			                   mLocalDescription->media(i))) {
-
-				mMidFromSsrc.emplace(ssrc, *found);
-				return *found;
-			}
-		}
-	}
-
-	return nullopt;
-}
-
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 	if (auto channel = findDataChannel(stream))
 	if (auto channel = findDataChannel(stream))
 		channel->triggerBufferedAmount(amount);
 		channel->triggerBufferedAmount(amount);
@@ -993,6 +941,8 @@ void PeerConnection::processLocalDescription(Description description) {
 	if (description.mediaCount() == 0)
 	if (description.mediaCount() == 0)
 		throw std::logic_error("Local description has no media line");
 		throw std::logic_error("Local description has no media line");
 
 
+	updateTrackSsrcCache(description);
+
 	{
 	{
 		// Set as local description
 		// Set as local description
 		std::lock_guard lock(mLocalDescriptionMutex);
 		std::lock_guard lock(mLocalDescriptionMutex);
@@ -1037,6 +987,8 @@ void PeerConnection::processLocalCandidate(Candidate candidate) {
 }
 }
 
 
 void PeerConnection::processRemoteDescription(Description description) {
 void PeerConnection::processRemoteDescription(Description description) {
+	updateTrackSsrcCache(description);
+
 	{
 	{
 		// Set as remote description
 		// Set as remote description
 		std::lock_guard lock(mRemoteDescriptionMutex);
 		std::lock_guard lock(mRemoteDescriptionMutex);
@@ -1218,4 +1170,39 @@ void PeerConnection::resetCallbacks() {
 	gatheringStateChangeCallback = nullptr;
 	gatheringStateChangeCallback = nullptr;
 }
 }
 
 
+void PeerConnection::updateTrackSsrcCache(const Description &description) {
+	std::unique_lock lock(mTracksMutex); // for safely writing to mTracksBySsrc
+
+	// Setup SSRC -> Track mapping
+	for (unsigned int i = 0; i < description.mediaCount(); ++i)
+		std::visit( // ssrc -> track mapping
+		    rtc::overloaded{
+		        [&](Description::Application const *) { return; },
+		        [&](Description::Media const *media) {
+			        const auto ssrcs = media->getSSRCs();
+
+			        // Note: We don't want to lock (or do any other lookups), if we
+			        // already know there's no SSRCs to loop over.
+			        if (ssrcs.size() <= 0) {
+				        return;
+			        }
+
+			        std::shared_ptr<Track> track{nullptr};
+			        if (auto it = mTracks.find(media->mid()); it != mTracks.end())
+				        if (auto track_for_mid = it->second.lock())
+					        track = track_for_mid;
+
+			        if (!track) {
+				        // Unable to find track for MID
+				        return;
+			        }
+
+			        for (auto ssrc : ssrcs) {
+				        mTracksBySsrc.emplace(ssrc, track);
+			        }
+		        },
+		    },
+		    description.media(i));
+}
+
 } // namespace rtc::impl
 } // namespace rtc::impl

+ 5 - 5
src/impl/peerconnection.hpp

@@ -65,7 +65,6 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
 	void forwardMessage(message_ptr message);
 	void forwardMessage(message_ptr message);
 	void forwardMedia(message_ptr message);
 	void forwardMedia(message_ptr message);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
-	optional<string> getMidFromSsrc(uint32_t ssrc);
 
 
 	shared_ptr<DataChannel> emplaceDataChannel(string label, DataChannelInit init);
 	shared_ptr<DataChannel> emplaceDataChannel(string label, DataChannelInit init);
 	shared_ptr<DataChannel> findDataChannel(uint16_t stream);
 	shared_ptr<DataChannel> findDataChannel(uint16_t stream);
@@ -126,6 +125,8 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
 	synchronized_callback<shared_ptr<rtc::Track>> trackCallback;
 	synchronized_callback<shared_ptr<rtc::Track>> trackCallback;
 
 
 private:
 private:
+	void updateTrackSsrcCache(const Description &description);
+
 	const init_token mInitToken = Init::Instance().token();
 	const init_token mInitToken = Init::Instance().token();
 	const future_certificate_ptr mCertificate;
 	const future_certificate_ptr mCertificate;
 
 
@@ -142,14 +143,13 @@ private:
 	std::vector<weak_ptr<DataChannel>> mUnassignedDataChannels;
 	std::vector<weak_ptr<DataChannel>> mUnassignedDataChannels;
 	std::shared_mutex mDataChannelsMutex;
 	std::shared_mutex mDataChannelsMutex;
 
 
-	std::unordered_map<string, weak_ptr<Track>> mTracks; // by mid
-	std::vector<weak_ptr<Track>> mTrackLines;            // by SDP order
+	std::unordered_map<string, weak_ptr<Track>> mTracks;         // by mid
+	std::unordered_map<uint32_t, weak_ptr<Track>> mTracksBySsrc; // by SSRC
+	std::vector<weak_ptr<Track>> mTrackLines;                    // by SDP order
 	std::shared_mutex mTracksMutex;
 	std::shared_mutex mTracksMutex;
 
 
 	Queue<shared_ptr<DataChannel>> mPendingDataChannels;
 	Queue<shared_ptr<DataChannel>> mPendingDataChannels;
 	Queue<shared_ptr<Track>> mPendingTracks;
 	Queue<shared_ptr<Track>> mPendingTracks;
-
-	std::unordered_map<uint32_t, string> mMidFromSsrc; // cache
 };
 };
 
 
 } // namespace rtc::impl
 } // namespace rtc::impl

+ 17 - 2
test/track.cpp

@@ -81,6 +81,7 @@ void test_track() {
 
 
 	shared_ptr<Track> t2;
 	shared_ptr<Track> t2;
 	string newTrackMid;
 	string newTrackMid;
+	Description::Video media;
 	pc2.onTrack([&t2, &newTrackMid](shared_ptr<Track> t) {
 	pc2.onTrack([&t2, &newTrackMid](shared_ptr<Track> t) {
 		string mid = t->mid();
 		string mid = t->mid();
 		cout << "Track 2: Received track with mid \"" << mid << "\"" << endl;
 		cout << "Track 2: Received track with mid \"" << mid << "\"" << endl;
@@ -99,7 +100,13 @@ void test_track() {
 
 
 	// Test opening a track
 	// Test opening a track
 	newTrackMid = "test";
 	newTrackMid = "test";
-	auto t1 = pc1.addTrack(Description::Video(newTrackMid));
+
+	media = Description::Video(newTrackMid, Description::Direction::SendOnly);
+	media.addH264Codec(96);
+	media.setBitrate(3000);
+	media.addSSRC(1234, "video-send");
+
+	auto t1 = pc1.addTrack(media);
 
 
 	pc1.setLocalDescription();
 	pc1.setLocalDescription();
 
 
@@ -117,7 +124,15 @@ void test_track() {
 
 
 	// Test renegotiation
 	// Test renegotiation
 	newTrackMid = "added";
 	newTrackMid = "added";
-	t1 = pc1.addTrack(Description::Video(newTrackMid));
+
+	media = Description::Video(newTrackMid, Description::Direction::SendOnly);
+	media.addH264Codec(96);
+	media.setBitrate(3000);
+	media.addSSRC(2468, "video-send");
+
+	// NOTE: Overwriting the old shared_ptr for t1 will cause it's respective
+	//       track to be dropped (so it's SSRCs won't be on the description next time)
+	t1 = pc1.addTrack(media);
 
 
 	pc1.setLocalDescription();
 	pc1.setLocalDescription();