Переглянути джерело

Refactored SDP parsing and generation to handle media

Paul-Louis Ageneau 5 роки тому
батько
коміт
9204787677
4 змінених файлів з 196 додано та 71 видалено
  1. 21 4
      include/rtc/description.hpp
  2. 2 0
      include/rtc/peerconnection.hpp
  3. 119 46
      src/description.cpp
  4. 54 21
      src/peerconnection.cpp

+ 21 - 4
include/rtc/description.hpp

@@ -42,7 +42,7 @@ public:
 	string typeString() const;
 	Role role() const;
 	string roleString() const;
-	string mid() const;
+	string dataMid() const;
 	std::optional<string> fingerprint() const;
 	std::optional<uint16_t> sctpPort() const;
 	std::optional<size_t> maxMessageSize() const;
@@ -65,11 +65,28 @@ private:
 	Type mType;
 	Role mRole;
 	string mSessionId;
-	string mMid;
 	string mIceUfrag, mIcePwd;
 	std::optional<string> mFingerprint;
-	std::optional<uint16_t> mSctpPort;
-	std::optional<size_t> mMaxMessageSize;
+
+	// Data
+	struct Data {
+		string mid;
+		std::optional<uint16_t> sctpPort;
+		std::optional<size_t> maxMessageSize;
+	};
+	Data mData;
+
+	// Media (non-data)
+	struct Media {
+		string description;
+		string mid;
+		std::vector<string> attributes;
+
+		string type() const;
+	};
+	std::map<string, Media> mMedia; // by mid
+
+	// Candidates
 	std::vector<Candidate> mCandidates;
 	bool mTrickle;
 

+ 2 - 0
include/rtc/peerconnection.hpp

@@ -75,6 +75,7 @@ public:
 	std::optional<string> localAddress() const;
 	std::optional<string> remoteAddress() const;
 
+	void setLocalDescription(Description description);
 	void setRemoteDescription(Description description);
 	void addRemoteCandidate(Candidate candidate);
 
@@ -104,6 +105,7 @@ private:
 	void endLocalCandidates();
 	bool checkFingerprint(const std::string &fingerprint) const;
 	void forwardMessage(message_ptr message);
+	void forwardMedia(message_ptr message);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 
 	std::shared_ptr<DataChannel> emplaceDataChannel(Description::Role role, const string &label,

+ 119 - 46
src/description.cpp

@@ -29,7 +29,7 @@ using std::string;
 
 namespace {
 
-inline bool hasprefix(const string &str, const string &prefix) {
+inline bool match_prefix(const string &str, const string &prefix) {
 	return str.size() >= prefix.size() &&
 	       std::mismatch(prefix.begin(), prefix.end(), str.begin()).first == prefix.end();
 }
@@ -50,7 +50,8 @@ Description::Description(const string &sdp, const string &typeString)
 Description::Description(const string &sdp, Type type) : Description(sdp, type, Role::ActPass) {}
 
 Description::Description(const string &sdp, Type type, Role role)
-    : mType(Type::Unspec), mRole(role), mMid("0"), mIceUfrag(""), mIcePwd(""), mTrickle(true) {
+    : mType(Type::Unspec), mRole(role), mIceUfrag(""), mIcePwd(""), mTrickle(true) {
+	mData.mid = "data";
 	hintType(type);
 
 	auto seed = std::chrono::system_clock::now().time_since_epoch().count();
@@ -59,37 +60,78 @@ Description::Description(const string &sdp, Type type, Role role)
 	mSessionId = std::to_string(uniform(generator));
 
 	std::istringstream ss(sdp);
+	std::optional<Media> currentMedia;
+
 	string line;
-	while (std::getline(ss, line)) {
+	bool finished;
+	do {
+		bool finished = !std::getline(ss, line) && line.empty();
 		trim_end(line);
-		if (hasprefix(line, "a=setup:")) {
-			const string setup = line.substr(line.find(':') + 1);
-			if (setup == "active")
-				mRole = Role::Active;
-			else if (setup == "passive")
-				mRole = Role::Passive;
-			else
-				mRole = Role::ActPass;
-		} else if (hasprefix(line, "a=mid:")) {
-			mMid = line.substr(line.find(':') + 1);
-		} else if (hasprefix(line, "a=fingerprint:sha-256")) {
-			mFingerprint = line.substr(line.find(' ') + 1);
-			std::transform(mFingerprint->begin(), mFingerprint->end(), mFingerprint->begin(),
-						   [](char c) { return std::toupper(c); });
-		} else if (hasprefix(line, "a=ice-ufrag")) {
-			mIceUfrag = line.substr(line.find(':') + 1);
-		} else if (hasprefix(line, "a=ice-pwd")) {
-			mIcePwd = line.substr(line.find(':') + 1);
-		} else if (hasprefix(line, "a=sctp-port")) {
-			mSctpPort = uint16_t(std::stoul(line.substr(line.find(':') + 1)));
-		} else if (hasprefix(line, "a=max-message-size")) {
-			mMaxMessageSize = size_t(std::stoul(line.substr(line.find(':') + 1)));
-		} else if (hasprefix(line, "a=candidate")) {
-			addCandidate(Candidate(line.substr(2), mMid));
-		} else if (hasprefix(line, "a=end-of-candidates")) {
-			mTrickle = false;
+
+		// Media description line (aka m-line)
+		if (finished || match_prefix(line, "m=")) {
+			if (currentMedia) {
+				if (!currentMedia->mid.empty()) {
+					if (currentMedia->type() == "application")
+						mData.mid = currentMedia->mid;
+					else
+						mMedia.emplace(currentMedia->mid, std::move(*currentMedia));
+				} else {
+					PLOG_WARNING << "SDP \"m=\" line has no mid, ignoring";
+				}
+			}
+			if (!finished)
+				currentMedia.emplace(Media{line.substr(2)});
+
+			// Attribute line
+		} else if (match_prefix(line, "a=")) {
+			string attr = line.substr(2);
+
+			string key, value;
+			if (size_t separator = attr.find(':'); separator != string::npos) {
+				key = attr.substr(0, separator);
+				value = attr.substr(separator + 1);
+			} else {
+				key = attr;
+			}
+
+			if (key == "mid") {
+				if (currentMedia)
+					currentMedia->mid = value;
+
+			} else if (key == "setup") {
+				if (value == "active")
+					mRole = Role::Active;
+				else if (value == "passive")
+					mRole = Role::Passive;
+				else
+					mRole = Role::ActPass;
+
+			} else if (key == "fingerprint") {
+				if (match_prefix(value, "sha-256 ")) {
+					mFingerprint = value.substr(8);
+					std::transform(mFingerprint->begin(), mFingerprint->end(),
+					               mFingerprint->begin(), [](char c) { return std::toupper(c); });
+				} else {
+					PLOG_WARNING << "Unknown SDP fingerprint type: " << value;
+				}
+			} else if (key == "ice-ufrag") {
+				mIceUfrag = value;
+			} else if (key == "ice-pwd") {
+				mIcePwd = value;
+			} else if (key == "sctp-port") {
+				mData.sctpPort = uint16_t(std::stoul(value));
+			} else if (key == "max-message-size") {
+				mData.maxMessageSize = size_t(std::stoul(value));
+			} else if (key == "candidate") {
+				addCandidate(Candidate(attr, currentMedia ? currentMedia->mid : mData.mid));
+			} else if (key == "end-of-candidates") {
+				mTrickle = false;
+			} else if (currentMedia) {
+				currentMedia->attributes.emplace_back(line.substr(2));
+			}
 		}
-	}
+	} while (!finished);
 }
 
 Description::Type Description::type() const { return mType; }
@@ -100,13 +142,13 @@ Description::Role Description::role() const { return mRole; }
 
 string Description::roleString() const { return roleToString(mRole); }
 
-string Description::mid() const { return mMid; }
+string Description::dataMid() const { return mData.mid; }
 
 std::optional<string> Description::fingerprint() const { return mFingerprint; }
 
-std::optional<uint16_t> Description::sctpPort() const { return mSctpPort; }
+std::optional<uint16_t> Description::sctpPort() const { return mData.sctpPort; }
 
-std::optional<size_t> Description::maxMessageSize() const { return mMaxMessageSize; }
+std::optional<size_t> Description::maxMessageSize() const { return mData.maxMessageSize; }
 
 bool Description::trickleEnabled() const { return mTrickle; }
 
@@ -122,9 +164,9 @@ void Description::setFingerprint(string fingerprint) {
 	mFingerprint.emplace(std::move(fingerprint));
 }
 
-void Description::setSctpPort(uint16_t port) { mSctpPort.emplace(port); }
+void Description::setSctpPort(uint16_t port) { mData.sctpPort.emplace(port); }
 
-void Description::setMaxMessageSize(size_t size) { mMaxMessageSize.emplace(size); }
+void Description::setMaxMessageSize(size_t size) { mData.maxMessageSize.emplace(size); }
 
 void Description::addCandidate(Candidate candidate) {
 	mCandidates.emplace_back(std::move(candidate));
@@ -146,36 +188,67 @@ string Description::generateSdp(const string &eol) const {
 		throw std::logic_error("Fingerprint must be set to generate an SDP string");
 
 	std::ostringstream sdp;
+
+	// Header
 	sdp << "v=0" << eol;
 	sdp << "o=- " << mSessionId << " 0 IN IP4 127.0.0.1" << eol;
 	sdp << "s=-" << eol;
 	sdp << "t=0 0" << eol;
-	sdp << "a=group:BUNDLE 0" << eol;
+
+	// Bundle
+	sdp << "a=group:BUNDLE";
+	for (const auto &[mid, _] : mMedia)
+		sdp << " " << mid;
+	sdp << " " << mData.mid << eol;
+
+	// Non-data media
+	if (!mMedia.empty()) {
+		// Lip-sync
+		sdp << "a=group:LS";
+		for (const auto &[mid, _] : mMedia)
+			sdp << " " << mid;
+		sdp << eol;
+
+		// Descriptions and attributes
+		for (const auto &[_, media] : mMedia) {
+			sdp << "m=" << media.description << eol;
+			sdp << "c=IN IP4 0.0.0.0" << eol;
+			sdp << "a=mid:" << media.mid << eol;
+			for (const auto &attr : media.attributes)
+				sdp << "a=" << attr << eol;
+		}
+	}
+
+	// Data
 	sdp << "m=application 9 UDP/DTLS/SCTP webrtc-datachannel" << eol;
 	sdp << "c=IN IP4 0.0.0.0" << eol;
-	sdp << "a=ice-ufrag:" << mIceUfrag << eol;
-	sdp << "a=ice-pwd:" << mIcePwd << eol;
+	sdp << "a=mid:" << mData.mid << eol;
+	if (mData.sctpPort)
+		sdp << "a=sctp-port:" << *mData.sctpPort << eol;
+	if (mData.maxMessageSize)
+		sdp << "a=max-message-size:" << *mData.maxMessageSize << eol;
+
+	// Common
 	if (mTrickle)
 		sdp << "a=ice-options:trickle" << eol;
-	sdp << "a=mid:" << mMid << eol;
+	sdp << "a=ice-ufrag:" << mIceUfrag << eol;
+	sdp << "a=ice-pwd:" << mIcePwd << eol;
 	sdp << "a=setup:" << roleToString(mRole) << eol;
 	sdp << "a=dtls-id:1" << eol;
 	if (mFingerprint)
 		sdp << "a=fingerprint:sha-256 " << *mFingerprint << eol;
-	if (mSctpPort)
-		sdp << "a=sctp-port:" << *mSctpPort << eol;
-	if (mMaxMessageSize)
-		sdp << "a=max-message-size:" << *mMaxMessageSize << eol;
-	for (const auto &candidate : mCandidates) {
-		sdp << string(candidate) << eol;
-	}
 
+	// Candidates
+	for (const auto &candidate : mCandidates)
+		sdp << string(candidate) << eol;
 	if (!mTrickle)
 		sdp << "a=end-of-candidates" << eol;
 
 	return sdp.str();
 }
 
+string Description::Media::type() const { return description.substr(0, description.find(' ')); }
+
 Description::Type Description::stringToType(const string &typeString) {
 	if (typeString == "offer")
 		return Type::Offer;

+ 54 - 21
src/peerconnection.cpp

@@ -23,6 +23,10 @@
 #include "include.hpp"
 #include "sctptransport.hpp"
 
+#if RTC_ENABLE_MEDIA
+#include "dtlssrtptransport.hpp"
+#endif
+
 #include <thread>
 
 namespace rtc {
@@ -61,6 +65,21 @@ std::optional<Description> PeerConnection::remoteDescription() const {
 	return mRemoteDescription;
 }
 
+void PeerConnection::setLocalDescription(Description description) {
+	if (auto iceTransport = std::atomic_load(&mIceTransport)) {
+		throw std::logic_error("Local description is already set");
+	} else {
+		// RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of
+		// setup:actpass.
+		// See https://tools.ietf.org/html/rfc5763#section-5
+		iceTransport = initIceTransport(Description::Role::ActPass);
+		Description localDescription = iceTransport->getLocalDescription(Description::Type::Offer);
+		localDescription.addMedia(description); // TODO
+		processLocalDescription(description);
+		iceTransport->gatherLocalCandidates();
+	}
+}
+
 void PeerConnection::setRemoteDescription(Description description) {
 	description.hintType(localDescription() ? Description::Type::Answer : Description::Type::Offer);
 	auto remoteCandidates = description.extractCandidates();
@@ -252,27 +271,37 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 			return transport;
 
 		auto lower = std::atomic_load(&mIceTransport);
-		auto transport = std::make_shared<DtlsTransport>(
-		    lower, mCertificate, weak_bind(&PeerConnection::checkFingerprint, this, _1),
-		    [this, weak_this = weak_from_this()](DtlsTransport::State state) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-			    switch (state) {
-			    case DtlsTransport::State::Connected:
-				    initSctpTransport();
-				    break;
-			    case DtlsTransport::State::Failed:
-				    changeState(State::Failed);
-				    break;
-			    case DtlsTransport::State::Disconnected:
-				    changeState(State::Disconnected);
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
-		    });
+		auto verifierCallback = weak_bind(&PeerConnection::checkFingerprint, this, _1);
+		auto stateChangeCallback = [this,
+		                            weak_this = weak_from_this()](DtlsTransport::State state) {
+			switch (state) {
+			case DtlsTransport::State::Connected:
+				initSctpTransport();
+				break;
+			case DtlsTransport::State::Failed:
+				changeState(State::Failed);
+				break;
+			case DtlsTransport::State::Disconnected:
+				changeState(State::Disconnected);
+				break;
+			default:
+				// Ignore
+				break;
+			}
+		};
+
+		shared_ptr<DtlsTransport> transport;
+#if RTC_ENABLE_MEDIA
+		auto local = localDescription();
+		auto remote = remoteDescription();
+		if ((local && local->hasMedia()) || (remote && remote->hasMedia()))
+			transport = std::make_shared<DtlsSrtpTransport>(
+			    lower, mCertificate, verifierCallback,
+			    std::bind(&PeerConnection::forwardMedia, this, _1), stateChangeCallback);
+		else
+#endif
+			transport = std::make_shared<DtlsTransport>(lower, mCertificate, verifierCallback,
+			                                            stateChangeCallback);
 
 		std::atomic_store(&mDtlsTransport, transport);
 		if (mState == State::Closed) {
@@ -413,6 +442,10 @@ void PeerConnection::forwardMessage(message_ptr message) {
 	channel->incoming(message);
 }
 
+void PeerConnection::forwardMedia(message_ptr message) {
+	// TODO
+}
+
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 	if (auto channel = findDataChannel(stream))
 		channel->triggerBufferedAmount(amount);