Pārlūkot izejas kodu

Integrated Track into PeerConnection

Paul-Louis Ageneau 5 gadi atpakaļ
vecāks
revīzija
749836b1fe

+ 15 - 13
include/rtc/peerconnection.hpp

@@ -28,6 +28,7 @@
 #include "message.hpp"
 #include "message.hpp"
 #include "reliability.hpp"
 #include "reliability.hpp"
 #include "rtc.hpp"
 #include "rtc.hpp"
+#include "track.hpp"
 
 
 #include <atomic>
 #include <atomic>
 #include <functional>
 #include <functional>
@@ -85,8 +86,8 @@ public:
 	                          std::optional<Description> mediaDescription = nullopt);
 	                          std::optional<Description> mediaDescription = nullopt);
 	void addRemoteCandidate(Candidate candidate);
 	void addRemoteCandidate(Candidate candidate);
 
 
-	std::shared_ptr<DataChannel> createDataChannel(const string &label, const string &protocol = "",
-	                                               const Reliability &reliability = {});
+	std::shared_ptr<DataChannel> createDataChannel(string label, string protocol = "",
+	                                               Reliability reliability = {});
 
 
 	void onDataChannel(std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback);
 	void onDataChannel(std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback);
 	void onLocalDescription(std::function<void(Description description)> callback);
 	void onLocalDescription(std::function<void(Description description)> callback);
@@ -100,12 +101,10 @@ public:
 	size_t bytesReceived();
 	size_t bytesReceived();
 	std::optional<std::chrono::milliseconds> rtt();
 	std::optional<std::chrono::milliseconds> rtt();
 
 
-	// Media
+	// Media support requires compilation with SRTP
 	bool hasMedia() const;
 	bool hasMedia() const;
-	void sendMedia(binary packet);
-	void sendMedia(const byte *packet, size_t size);
-
-	void onMedia(std::function<void(binary)> callback);
+	std::shared_ptr<Track> createTrack(string mid);
+	void onTrack(std::function<void(std::shared_ptr<Track> track)> callback);
 
 
 	// libnice only
 	// libnice only
 	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
 	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
@@ -122,18 +121,20 @@ private:
 	void forwardMedia(message_ptr message);
 	void forwardMedia(message_ptr message);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 
 
-	std::shared_ptr<DataChannel> emplaceDataChannel(Description::Role role, const string &label,
-	                                                const string &protocol,
-	                                                const Reliability &reliability);
+	std::shared_ptr<DataChannel> emplaceDataChannel(Description::Role role, string label,
+	                                                string protocol, Reliability reliability);
 	std::shared_ptr<DataChannel> findDataChannel(uint16_t stream);
 	std::shared_ptr<DataChannel> findDataChannel(uint16_t stream);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void openDataChannels();
 	void openDataChannels();
 	void closeDataChannels();
 	void closeDataChannels();
 	void remoteCloseDataChannels();
 	void remoteCloseDataChannels();
 
 
+	void openTracks();
+
 	void processLocalDescription(Description description);
 	void processLocalDescription(Description description);
 	void processLocalCandidate(Candidate candidate);
 	void processLocalCandidate(Candidate candidate);
 	void triggerDataChannel(std::weak_ptr<DataChannel> weakDataChannel);
 	void triggerDataChannel(std::weak_ptr<DataChannel> weakDataChannel);
+	void triggerTrack(std::weak_ptr<Track> weakTrack);
 	bool changeState(State state);
 	bool changeState(State state);
 	bool changeGatheringState(GatheringState state);
 	bool changeGatheringState(GatheringState state);
 
 
@@ -153,8 +154,9 @@ private:
 	std::shared_ptr<DtlsTransport> mDtlsTransport;
 	std::shared_ptr<DtlsTransport> mDtlsTransport;
 	std::shared_ptr<SctpTransport> mSctpTransport;
 	std::shared_ptr<SctpTransport> mSctpTransport;
 
 
-	std::unordered_map<unsigned int, std::weak_ptr<DataChannel>> mDataChannels;
-	std::shared_mutex mDataChannelsMutex;
+	std::unordered_map<unsigned int, std::weak_ptr<DataChannel>> mDataChannels; // by stream ID
+	std::unordered_map<string, std::weak_ptr<Track>> mTracks;                   // by mid
+	std::shared_mutex mDataChannelsMutex, mTracksMutex;
 
 
 	std::atomic<State> mState;
 	std::atomic<State> mState;
 	std::atomic<GatheringState> mGatheringState;
 	std::atomic<GatheringState> mGatheringState;
@@ -164,7 +166,7 @@ private:
 	synchronized_callback<Candidate> mLocalCandidateCallback;
 	synchronized_callback<Candidate> mLocalCandidateCallback;
 	synchronized_callback<State> mStateChangeCallback;
 	synchronized_callback<State> mStateChangeCallback;
 	synchronized_callback<GatheringState> mGatheringStateChangeCallback;
 	synchronized_callback<GatheringState> mGatheringStateChangeCallback;
-	synchronized_callback<binary> mMediaCallback;
+	synchronized_callback<std::shared_ptr<Track>> mTrackCallback;
 };
 };
 
 
 } // namespace rtc
 } // namespace rtc

+ 8 - 3
include/rtc/track.hpp

@@ -29,12 +29,13 @@
 
 
 namespace rtc {
 namespace rtc {
 
 
-class PeerConnection;
+#if RTC_ENABLE_MEDIA
 class DtlsSrtpTransport;
 class DtlsSrtpTransport;
+#endif
 
 
 class Track final : public std::enable_shared_from_this<Track>, public Channel {
 class Track final : public std::enable_shared_from_this<Track>, public Channel {
 public:
 public:
-	Track(string mid, std::shared_ptr<DtlsSrtpTransport> transport = nullptr);
+	Track(string mid);
 	~Track() = default;
 	~Track() = default;
 
 
 	string mid() const;
 	string mid() const;
@@ -52,12 +53,15 @@ public:
 	std::optional<message_variant> receive() override;
 	std::optional<message_variant> receive() override;
 
 
 private:
 private:
+#if RTC_ENABLE_MEDIA
 	void open(std::shared_ptr<DtlsSrtpTransport> transport);
 	void open(std::shared_ptr<DtlsSrtpTransport> transport);
+	std::weak_ptr<DtlsSrtpTransport> mDtlsSrtpTransport;
+#endif
+
 	bool outgoing(message_ptr message);
 	bool outgoing(message_ptr message);
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);
 
 
 	const string mMid;
 	const string mMid;
-	std::weak_ptr<DtlsSrtpTransport> mDtlsSrtpTransport;
 	std::atomic<bool> mIsClosed = false;
 	std::atomic<bool> mIsClosed = false;
 
 
 	Queue<message_ptr> mRecvQueue;
 	Queue<message_ptr> mRecvQueue;
@@ -68,3 +72,4 @@ private:
 } // namespace rtc
 } // namespace rtc
 
 
 #endif
 #endif
+

+ 3 - 3
src/datachannel.cpp

@@ -198,13 +198,13 @@ bool DataChannel::outgoing(message_ptr message) {
 	if (mIsClosed)
 	if (mIsClosed)
 		throw std::runtime_error("DataChannel is closed");
 		throw std::runtime_error("DataChannel is closed");
 
 
+	if (message->size() > maxMessageSize())
+		throw std::runtime_error("Message size exceeds limit");
+
 	auto transport = mSctpTransport.lock();
 	auto transport = mSctpTransport.lock();
 	if (!transport)
 	if (!transport)
 		throw std::runtime_error("DataChannel transport is not open");
 		throw std::runtime_error("DataChannel transport is not open");
 
 
-	if (message->size() > maxMessageSize())
-		throw std::runtime_error("Message size exceeds limit");
-
 	// Before the ACK has been received on a DataChannel, all messages must be sent ordered
 	// Before the ACK has been received on a DataChannel, all messages must be sent ordered
 	message->reliability = mIsOpen ? mReliability : nullptr;
 	message->reliability = mIsOpen ? mReliability : nullptr;
 	message->stream = mStream;
 	message->stream = mStream;

+ 4 - 0
src/dtlssrtptransport.cpp

@@ -175,6 +175,8 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 				return;
 				return;
 			}
 			}
 			PLOG_VERBOSE << "Unprotected SRTCP packet, size=" << size;
 			PLOG_VERBOSE << "Unprotected SRTCP packet, size=" << size;
+			message->type = Message::Type::Control;
+			message->stream = to_integer<uint8_t>(*(message->begin() + 1)); // Payload Type
 		} else {
 		} else {
 			PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
 			PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
 			if (srtp_err_status_t err = srtp_unprotect(mSrtpIn, message->data(), &size)) {
 			if (srtp_err_status_t err = srtp_unprotect(mSrtpIn, message->data(), &size)) {
@@ -187,6 +189,8 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 				return;
 				return;
 			}
 			}
 			PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
 			PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
+			message->type = Message::Type::Binary;
+			message->stream = value2; // Payload Type
 		}
 		}
 
 
 		message->resize(size);
 		message->resize(size);

+ 59 - 43
src/peerconnection.cpp

@@ -185,9 +185,8 @@ std::optional<string> PeerConnection::remoteAddress() const {
 	return iceTransport ? iceTransport->getRemoteAddress() : nullopt;
 	return iceTransport ? iceTransport->getRemoteAddress() : nullopt;
 }
 }
 
 
-shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
-                                                          const string &protocol,
-                                                          const Reliability &reliability) {
+shared_ptr<DataChannel> PeerConnection::createDataChannel(string label, string protocol,
+                                                          Reliability reliability) {
 	// RFC 5763: The answerer MUST use either a setup attribute value of setup:active or
 	// RFC 5763: The answerer MUST use either a setup attribute value of setup:active or
 	// setup:passive. [...] Thus, setup:active is RECOMMENDED.
 	// setup:passive. [...] Thus, setup:active is RECOMMENDED.
 	// See https://tools.ietf.org/html/rfc5763#section-5
 	// See https://tools.ietf.org/html/rfc5763#section-5
@@ -195,7 +194,8 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto iceTransport = std::atomic_load(&mIceTransport);
 	auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
 	auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
 
 
-	auto channel = emplaceDataChannel(role, label, protocol, reliability);
+	auto channel =
+	    emplaceDataChannel(role, std::move(label), std::move(protocol), std::move(reliability));
 
 
 	if (!iceTransport) {
 	if (!iceTransport) {
 		// RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of
 		// RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of
@@ -239,29 +239,17 @@ bool PeerConnection::hasMedia() const {
 	return (local && local->hasMedia()) || (remote && remote->hasMedia());
 	return (local && local->hasMedia()) || (remote && remote->hasMedia());
 }
 }
 
 
-void PeerConnection::sendMedia(binary packet) {
-	outgoingMedia(make_message(std::move(packet), Message::Binary));
-}
+std::shared_ptr<Track> PeerConnection::createTrack(string mid) {
+	if (localDescription())
+		throw std::logic_error("Tracks must be created before local description");
 
 
-void PeerConnection::sendMedia(const byte *packet, size_t size) {
-	outgoingMedia(make_message(packet, packet + size, Message::Binary));
+	auto track = std::make_shared<Track>(mid);
+	mTracks.emplace(std::make_pair(mid, track));
+	return track;
 }
 }
 
 
-void PeerConnection::onMedia(std::function<void(binary)> callback) { mMediaCallback = callback; }
-
-void PeerConnection::outgoingMedia([[maybe_unused]] message_ptr message) {
-	if (!hasMedia())
-		throw std::runtime_error("PeerConnection has no media support");
-
-#if RTC_ENABLE_MEDIA
-	auto transport = std::atomic_load(&mDtlsTransport);
-	if (!transport)
-		throw std::runtime_error("PeerConnection is not open");
-
-	std::dynamic_pointer_cast<DtlsSrtpTransport>(transport)->sendMedia(message);
-#else
-	PLOG_WARNING << "Ignoring sent media (not compiled with SRTP support)";
-#endif
+void PeerConnection::onTrack(std::function<void(std::shared_ptr<Track>)> callback) {
+	mTrackCallback = callback;
 }
 }
 
 
 shared_ptr<IceTransport> PeerConnection::initIceTransport(Description::Role role) {
 shared_ptr<IceTransport> PeerConnection::initIceTransport(Description::Role role) {
@@ -343,6 +331,7 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 			switch (state) {
 			switch (state) {
 			case DtlsTransport::State::Connected:
 			case DtlsTransport::State::Connected:
 				initSctpTransport();
 				initSctpTransport();
+				openTracks();
 				break;
 				break;
 			case DtlsTransport::State::Failed:
 			case DtlsTransport::State::Failed:
 				changeState(State::Failed);
 				changeState(State::Failed);
@@ -525,8 +514,16 @@ void PeerConnection::forwardMessage(message_ptr message) {
 }
 }
 
 
 void PeerConnection::forwardMedia(message_ptr message) {
 void PeerConnection::forwardMedia(message_ptr message) {
-	if (message)
-		mMediaCallback(std::move(*message));
+	if (!message)
+		return;
+
+	string mid;
+	// TODO: stream (PT) to mid
+
+	std::shared_lock lock(mTracksMutex); // read-only
+	if (auto it = mTracks.find(mid); it != mTracks.end())
+		if (auto track = it->second.lock())
+			track->incoming(message);
 }
 }
 
 
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
@@ -534,10 +531,9 @@ void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 		channel->triggerBufferedAmount(amount);
 		channel->triggerBufferedAmount(amount);
 }
 }
 
 
-shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(Description::Role role,
-                                                           const string &label,
-                                                           const string &protocol,
-                                                           const Reliability &reliability) {
+shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(Description::Role role, string label,
+                                                           string protocol,
+                                                           Reliability reliability) {
 	// The active side must use streams with even identifiers, whereas the passive side must use
 	// The active side must use streams with even identifiers, whereas the passive side must use
 	// streams with odd identifiers.
 	// streams with odd identifiers.
 	// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6
 	// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6
@@ -548,8 +544,8 @@ shared_ptr<DataChannel> PeerConnection::emplaceDataChannel(Description::Role rol
 		if (stream >= 65535)
 		if (stream >= 65535)
 			throw std::runtime_error("Too many DataChannels");
 			throw std::runtime_error("Too many DataChannels");
 	}
 	}
-	auto channel =
-	    std::make_shared<DataChannel>(shared_from_this(), stream, label, protocol, reliability);
+	auto channel = std::make_shared<DataChannel>(shared_from_this(), stream, std::move(label),
+	                                             std::move(protocol), std::move(reliability));
 	mDataChannels.emplace(std::make_pair(stream, channel));
 	mDataChannels.emplace(std::make_pair(stream, channel));
 	return channel;
 	return channel;
 }
 }
@@ -598,6 +594,21 @@ void PeerConnection::openDataChannels() {
 		iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->open(transport); });
 		iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->open(transport); });
 }
 }
 
 
+void PeerConnection::openTracks() {
+#if RTC_ENABLE_MEDIA
+	if (!hasMedia())
+		return;
+
+	if (auto transport = std::atomic_load(&mDtlsTransport)) {
+		auto srtpTransport = std::reinterpret_pointer_cast<DtlsSrtpTransport>(transport);
+		std::shared_lock lock(mTracksMutex); // read-only
+		for (auto it = mTracks.begin(); it != mTracks.end(); ++it)
+			if (auto track = it->second.lock())
+				track->open(srtpTransport);
+	}
+#endif
+}
+
 void PeerConnection::closeDataChannels() {
 void PeerConnection::closeDataChannels() {
 	iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->close(); });
 	iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->close(); });
 }
 }
@@ -614,18 +625,15 @@ void PeerConnection::processLocalDescription(Description description) {
 	    remoteSctpPort = remote->sctpPort();
 	    remoteSctpPort = remote->sctpPort();
 	}
 	}
 
 
-	auto certificate = mCertificate.get(); // wait for certificate if not ready
+	if (remoteDataMid)
+		description.setDataMid(*remoteDataMid);
 
 
-	{
-		std::lock_guard lock(mLocalDescriptionMutex);
-		mLocalDescription.emplace(std::move(description));
-		if (remoteDataMid)
-			mLocalDescription->setDataMid(*remoteDataMid);
-
-		mLocalDescription->setFingerprint(certificate->fingerprint());
-		mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
-		mLocalDescription->setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE);
-	}
+	description.setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
+	description.setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE);
+	description.setFingerprint(mCertificate.get()->fingerprint()); // wait for certificate
+
+	std::lock_guard lock(mLocalDescriptionMutex);
+	mLocalDescription.emplace(std::move(description));
 
 
 	mProcessor->enqueue([this, description = *mLocalDescription]() {
 	mProcessor->enqueue([this, description = *mLocalDescription]() {
 		mLocalDescriptionCallback(std::move(description));
 		mLocalDescriptionCallback(std::move(description));
@@ -653,6 +661,14 @@ void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {
 	    [this, dataChannel = std::move(dataChannel)]() { mDataChannelCallback(dataChannel); });
 	    [this, dataChannel = std::move(dataChannel)]() { mDataChannelCallback(dataChannel); });
 }
 }
 
 
+void PeerConnection::triggerTrack(std::weak_ptr<Track> weakTrack) {
+	auto track = weakTrack.lock();
+	if (!track)
+		return;
+
+	mProcessor->enqueue([this, track = std::move(track)]() { mTrackCallback(track); });
+}
+
 bool PeerConnection::changeState(State state) {
 bool PeerConnection::changeState(State state) {
 	State current;
 	State current;
 	do {
 	do {

+ 23 - 12
src/track.cpp

@@ -25,12 +25,7 @@ namespace rtc {
 using std::shared_ptr;
 using std::shared_ptr;
 using std::weak_ptr;
 using std::weak_ptr;
 
 
-Track::Track(string mid, shared_ptr<DtlsSrtpTransport> transport)
-    : mMid(std::move(mid)), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
-
-	if (transport)
-		open(transport);
-}
+Track::Track(string mid) : mMid(std::move(mid)), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
 
 
 string Track::mid() const { return mMid; }
 string Track::mid() const { return mMid; }
 
 
@@ -52,7 +47,13 @@ std::optional<message_variant> Track::receive() {
 	return nullopt;
 	return nullopt;
 }
 }
 
 
-bool Track::isOpen(void) const { return !mIsClosed && mDtlsSrtpTransport.lock(); }
+bool Track::isOpen(void) const {
+#if RTC_ENABLE_MEDIA
+	return !mIsClosed && mDtlsSrtpTransport.lock();
+#else
+	return !mIsClosed;
+#endif
+}
 
 
 bool Track::isClosed(void) const { return mIsClosed; }
 bool Track::isClosed(void) const { return mIsClosed; }
 
 
@@ -60,22 +61,31 @@ size_t Track::maxMessageSize() const {
 	return 65535 - 12 - 4; // SRTP/UDP
 	return 65535 - 12 - 4; // SRTP/UDP
 }
 }
 
 
-size_t Track::availableAmount() const { return mRecvQueue.amount(); }
+size_t Track::availableAmount() const {
+	return mRecvQueue.amount();
+}
 
 
+#if RTC_ENABLE_MEDIA
 void Track::open(shared_ptr<DtlsSrtpTransport> transport) { mDtlsSrtpTransport = transport; }
 void Track::open(shared_ptr<DtlsSrtpTransport> transport) { mDtlsSrtpTransport = transport; }
+#endif
 
 
 bool Track::outgoing(message_ptr message) {
 bool Track::outgoing(message_ptr message) {
 	if (mIsClosed)
 	if (mIsClosed)
 		throw std::runtime_error("Track is closed");
 		throw std::runtime_error("Track is closed");
 
 
+	if (message->size() > maxMessageSize())
+		throw std::runtime_error("Message size exceeds limit");
+
+#if RTC_ENABLE_MEDIA
 	auto transport = mDtlsSrtpTransport.lock();
 	auto transport = mDtlsSrtpTransport.lock();
 	if (!transport)
 	if (!transport)
 		throw std::runtime_error("Track transport is not open");
 		throw std::runtime_error("Track transport is not open");
 
 
-	if (message->size() > maxMessageSize())
-		throw std::runtime_error("Message size exceeds limit");
-
-	return transport->send(message);
+	return transport->sendMedia(message);
+#else
+	PLOG_WARNING << "Ignoring track send (not compiled with SRTP support)";
+	return false;
+#endif
 }
 }
 
 
 void Track::incoming(message_ptr message) {
 void Track::incoming(message_ptr message) {
@@ -91,3 +101,4 @@ void Track::incoming(message_ptr message) {
 }
 }
 
 
 } // namespace rtc
 } // namespace rtc
+