Browse Source

Fixed race condition with SCTP simultaneous open and added port handling

Paul-Louis Ageneau 6 years ago
parent
commit
549dca2436
6 changed files with 137 additions and 74 deletions
  1. 40 22
      src/description.cpp
  2. 4 3
      src/description.hpp
  3. 12 8
      src/peerconnection.cpp
  4. 1 1
      src/peerconnection.hpp
  5. 65 31
      src/sctptransport.cpp
  6. 15 9
      src/sctptransport.hpp

+ 40 - 22
src/description.cpp

@@ -44,7 +44,8 @@ inline void trim_end(string &str) {
 
 namespace rtc {
 
-Description::Description(Role role, const string &sdp) : mRole(role), mMid("0"), mIceUfrag("0"), mIcePwd("0") {
+Description::Description(Role role, const string &sdp)
+    : mRole(role), mMid("0"), mIceUfrag("0"), mIcePwd("0") {
 	auto seed = std::chrono::system_clock::now().time_since_epoch().count();
 	std::default_random_engine generator(seed);
 	std::uniform_int_distribution<uint32_t> uniform;
@@ -56,60 +57,78 @@ Description::Description(Role role, const string &sdp) : mRole(role), mMid("0"),
 		trim_end(line);
 		if (hasprefix(line, "a=setup:")) {
 			const string setup = line.substr(line.find(':') + 1);
-			if (setup == "active" && mRole == Role::Active) {
-				mRole = Role::Passive;
-			} else if (setup == "passive" && mRole == Role::Passive) {
+			if (setup == "active")
 				mRole = Role::Active;
-			} else { // actpass, nothing to do
-			}
+			else if (setup == "passive")
+				mRole = Role::Active;
+			else
+				mRole = Role::ActPass;
 		} else if (hasprefix(line, "a=mid:")) {
 			mMid = line.substr(line.find(':') + 1);
 		} else if (hasprefix(line, "a=fingerprint:sha-256")) {
 			mFingerprint = line.substr(line.find(' ') + 1);
+			std::transform(mFingerprint->begin(), mFingerprint->end(), mFingerprint->begin(),
+						   [](char c) { return std::toupper(c); });
 		} else if (hasprefix(line, "a=ice-ufrag")) {
 			mIceUfrag = line.substr(line.find(':') + 1);
 		} else if (hasprefix(line, "a=ice-pwd")) {
 			mIcePwd = line.substr(line.find(':') + 1);
+		} else if (hasprefix(line, "a=sctp-port")) {
+			mSctpPort = uint16_t(std::stoul(line.substr(line.find(':') + 1)));
 		}
 	}
 }
 
 Description::Role Description::role() const { return mRole; }
 
-std::optional<string> Description::fingerprint() const {
-	return mFingerprint;
+std::optional<string> Description::fingerprint() const { return mFingerprint; }
+
+std::optional<uint16_t> Description::sctpPort() const { return mSctpPort; }
+
+void Description::setFingerprint(string fingerprint) {
+	mFingerprint.emplace(std::move(fingerprint));
 }
 
-void Description::setFingerprint(const string &fingerprint) { mFingerprint = fingerprint; }
+void Description::setSctpPort(uint16_t port) { mSctpPort.emplace(port); }
 
 void Description::addCandidate(Candidate candidate) {
 	mCandidates.emplace_back(std::move(candidate));
 }
 
-void Description::addCandidate(Candidate &&candidate) {
-	mCandidates.emplace_back(std::forward<Candidate>(candidate));
-}
-
 Description::operator string() const {
 	if (!mFingerprint)
-		throw std::runtime_error("Fingerprint must be set to generate a SDP");
+		throw std::logic_error("Fingerprint must be set to generate a SDP");
+
+	string roleStr;
+	switch (mRole) {
+	case Role::Active:
+		roleStr = "active";
+		break;
+	case Role::Passive:
+		roleStr = "passive";
+		break;
+	default:
+		roleStr = "actpass";
+		break;
+	}
 
-    std::ostringstream sdp;
+	std::ostringstream sdp;
 	sdp << "v=0\n";
 	sdp << "o=- " << mSessionId << " 0 IN IP4 0.0.0.0\n";
 	sdp << "s=-\n";
 	sdp << "t=0 0\n";
-    sdp << "m=application 0 UDP/DTLS/SCTP webrtc-datachannel\n";
-    sdp << "c=IN IP4 0.0.0.0\n";
+	sdp << "m=application 0 UDP/DTLS/SCTP webrtc-datachannel\n";
+	sdp << "c=IN IP4 0.0.0.0\n";
 	sdp << "a=ice-options:trickle\n";
 	sdp << "a=ice-ufrag:" << mIceUfrag << "\n";
 	sdp << "a=ice-pwd:" << mIcePwd << "\n";
 	sdp << "a=mid:" << mMid << "\n";
-	sdp << "a=setup:" << (mRole == Role::Active ? "active" : "passive") << "\n";
+	sdp << "a=setup:" << roleStr << "\n";
 	sdp << "a=dtls-id:1\n";
-	sdp << "a=fingerprint:sha-256 " << *mFingerprint << "\n";
-    sdp << "a=sctp-port:5000\n";
-    // sdp << "a=max-message-size:100000\n";
+	if (mFingerprint)
+		sdp << "a=fingerprint:sha-256 " << *mFingerprint << "\n";
+	if (mSctpPort)
+		sdp << "a=sctp-port:" << *mSctpPort << "\n";
 
 	for (const auto &candidate : mCandidates) {
 		sdp << "a=candidate:" << string(candidate);
@@ -119,4 +138,3 @@ Description::operator string() const {
 }
 
 } // namespace rtc
-

+ 4 - 3
src/description.hpp

@@ -36,10 +36,11 @@ public:
 
 	Role role() const;
 	std::optional<string> fingerprint() const;
+	std::optional<uint16_t> sctpPort() const;
 
-	void setFingerprint(const string &fingerprint);
+	void setFingerprint(string fingerprint);
+	void setSctpPort(uint16_t port);
 	void addCandidate(Candidate candidate);
-	void addCandidate(Candidate &&candidate);
 
 	operator string() const;
 
@@ -49,6 +50,7 @@ private:
 	string mMid;
 	string mIceUfrag, mIcePwd;
 	std::optional<string> mFingerprint;
+	std::optional<uint16_t> mSctpPort;
 
 	std::vector<Candidate> mCandidates;
 };
@@ -56,4 +58,3 @@ private:
 } // namespace rtc
 
 #endif
-

+ 12 - 8
src/peerconnection.cpp

@@ -32,7 +32,8 @@ using std::function;
 using std::shared_ptr;
 
 PeerConnection::PeerConnection(const IceConfiguration &config)
-    : mConfig(config), mCertificate(make_certificate("libdatachannel")), mMid("0") {}
+    : mConfig(config), mCertificate(make_certificate("libdatachannel")), mMid("0"),
+      mSctpPort(5000) {}
 
 PeerConnection::~PeerConnection() {}
 
@@ -43,9 +44,13 @@ const Certificate *PeerConnection::certificate() const { return &mCertificate; }
 void PeerConnection::setRemoteDescription(const string &description) {
 	Description desc(Description::Role::ActPass, description);
 
-	if(auto fingerprint = desc.fingerprint())
+	if (auto fingerprint = desc.fingerprint())
 		mRemoteFingerprint.emplace(*fingerprint);
 
+	if (auto sctpPort = desc.sctpPort()) {
+		mSctpPort = *sctpPort;
+	}
+
 	if (!mIceTransport) {
 		initIceTransport(Description::Role::ActPass);
 		mIceTransport->setRemoteDescription(desc);
@@ -69,7 +74,7 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
 	auto seed = std::chrono::system_clock::now().time_since_epoch().count();
 	std::default_random_engine generator(seed);
 	std::uniform_int_distribution<uint16_t> uniform;
-	uint16_t stream = uniform(generator);
+	uint16_t stream = 0; // uniform(generator);
 
 	auto channel = std::make_shared<DataChannel>(stream, label, protocol, reliability);
 	mDataChannels.insert(std::make_pair(stream, channel));
@@ -112,7 +117,7 @@ void PeerConnection::initDtlsTransport() {
 
 void PeerConnection::initSctpTransport() {
 	mSctpTransport = std::make_shared<SctpTransport>(
-	    mDtlsTransport, std::bind(&PeerConnection::openDataChannels, this),
+	    mDtlsTransport, mSctpPort, std::bind(&PeerConnection::openDataChannels, this),
 	    std::bind(&PeerConnection::forwardMessage, this, _1));
 }
 
@@ -134,15 +139,15 @@ void PeerConnection::forwardMessage(message_ptr message) {
 }
 
 void PeerConnection::openDataChannels(void) {
-	for (auto it = mDataChannels.begin(); it != mDataChannels.end(); ++it) {
-		it->second->open(mSctpTransport);
-	}
+	for (const auto &[stream, dataChannel] : mDataChannels)
+		dataChannel->open(mSctpTransport);
 }
 
 void PeerConnection::triggerLocalDescription() {
 	if (mLocalDescriptionCallback && mIceTransport) {
 		Description desc{mIceTransport->getLocalDescription()};
 		desc.setFingerprint(mCertificate.fingerprint());
+		desc.setSctpPort(mSctpPort);
 		mLocalDescriptionCallback(string(desc));
 	}
 }
@@ -159,4 +164,3 @@ void PeerConnection::triggerDataChannel(std::shared_ptr<DataChannel> dataChannel
 }
 
 } // namespace rtc
-

+ 1 - 1
src/peerconnection.hpp

@@ -75,6 +75,7 @@ private:
 	Certificate mCertificate;
 	string mMid;
 	std::optional<string> mRemoteFingerprint;
+	uint16_t mSctpPort;
 
 	std::shared_ptr<IceTransport> mIceTransport;
 	std::shared_ptr<DtlsTransport> mDtlsTransport;
@@ -90,4 +91,3 @@ private:
 } // namespace rtc
 
 #endif
-

+ 65 - 31
src/sctptransport.cpp

@@ -18,26 +18,43 @@
 
 #include "sctptransport.hpp"
 
+#include <chrono>
 #include <exception>
 #include <iostream>
 #include <vector>
 
 #include <arpa/inet.h>
 
+using std::shared_ptr;
+
 namespace rtc {
 
-using std::shared_ptr;
+std::mutex SctpTransport::GlobalMutex;
+int SctpTransport::InstancesCount = 0;
+
+void SctpTransport::GlobalInit() {
+	std::unique_lock<std::mutex> lock(GlobalMutex);
+	if (InstancesCount++ == 0) {
+		usrsctp_init(0, &SctpTransport::WriteCallback, nullptr);
+		usrsctp_sysctl_set_sctp_ecn_enable(0);
+	}
+}
 
-SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, ready_callback ready,
+void SctpTransport::GlobalCleanup() {
+	std::unique_lock<std::mutex> lock(GlobalMutex);
+	if (InstancesCount-- == 0) {
+		usrsctp_finish();
+	}
+}
+
+SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, ready_callback ready,
                              message_callback recv)
-    : Transport(lower), mReadyCallback(std::move(ready)), mLocalPort(5000), mRemotePort(5000) {
+    : Transport(lower), mReadyCallback(std::move(ready)), mPort(port) {
 
 	onRecv(recv);
 
-	usrsctp_init(0, &SctpTransport::WriteCallback, nullptr);
-	usrsctp_sysctl_set_sctp_ecn_enable(0);
+	GlobalInit();
 	usrsctp_register_address(this);
-
 	mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, &SctpTransport::ReadCallback,
 	                       nullptr, 0, this);
 	if (!mSock)
@@ -52,7 +69,8 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, ready_callback re
 
 	struct sctp_paddrparams spp = {};
 	spp.spp_flags = SPP_PMTUD_DISABLE;
-	spp.spp_pathmtu = 1200; // TODO: MTU
+	spp.spp_pathmtu = 1200; // Max safe value recommended by RFC 8261
+	                        // See https://tools.ietf.org/html/rfc8261#section-5
 	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &spp, sizeof(spp)))
 		throw std::runtime_error("Could not set socket option SCTP_PEER_ADDR_PARAMS, errno=" +
 		                         std::to_string(errno));
@@ -86,7 +104,7 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, ready_callback re
 
 	struct sockaddr_conn sconn = {};
 	sconn.sconn_family = AF_CONN;
-	sconn.sconn_port = htons(mLocalPort);
+	sconn.sconn_port = htons(mPort);
 	sconn.sconn_addr = this;
 #ifdef HAVE_SCONN_LEN
 	sconn.sconn_len = sizeof(sconn);
@@ -109,7 +127,7 @@ SctpTransport::~SctpTransport() {
 	}
 
 	usrsctp_deregister_address(this);
-	usrsctp_finish();
+	GlobalCleanup();
 }
 
 bool SctpTransport::isReady() const { return mIsReady; }
@@ -153,8 +171,14 @@ int SctpTransport::WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t
 }
 
 int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df) {
-	const byte *b = reinterpret_cast<const byte *>(data);
+	byte *b = reinterpret_cast<byte *>(data);
 	outgoing(make_message(b, b + len));
+
+	if (!mConnectDataSent) {
+		std::unique_lock<std::mutex> lock(mConnectMutex);
+		mConnectDataSent = true;
+		mConnectCondition.notify_all();
+	}
 	return 0; // success
 }
 
@@ -201,22 +225,21 @@ void SctpTransport::processData(const byte *data, size_t len, uint16_t sid, Payl
 }
 
 bool SctpTransport::send(message_ptr message) {
-  const Reliability reliability =
-      message->reliability ? *message->reliability : Reliability();
-
-  struct sctp_sendv_spa spa = {};
-
-  uint32_t ppid;
-  switch (message->type) {
-  case Message::String:
-    ppid = message->empty() ? PPID_STRING : PPID_STRING_EMPTY;
-    break;
-  case Message::Binary:
-    ppid = message->empty() ? PPID_BINARY : PPID_BINARY_EMPTY;
-    break;
-  default:
-    ppid = PPID_CONTROL;
-    break;
+	const Reliability reliability = message->reliability ? *message->reliability : Reliability();
+
+	struct sctp_sendv_spa spa = {};
+
+	uint32_t ppid;
+	switch (message->type) {
+	case Message::String:
+		ppid = message->empty() ? PPID_STRING : PPID_STRING_EMPTY;
+		break;
+	case Message::Binary:
+		ppid = message->empty() ? PPID_BINARY : PPID_BINARY_EMPTY;
+		break;
+	default:
+		ppid = PPID_CONTROL;
+		break;
 	}
 
 	// set sndinfo
@@ -259,7 +282,7 @@ bool SctpTransport::send(message_ptr message) {
 void SctpTransport::reset(unsigned int stream) {
 	using srs_t = struct sctp_reset_streams;
 	const size_t len = sizeof(srs_t) + sizeof(uint16_t);
-	std::byte buffer[len] = {};
+	byte buffer[len] = {};
 	srs_t &srs = *reinterpret_cast<srs_t *>(buffer);
 	srs.srs_flags = SCTP_STREAM_RESET_OUTGOING;
 	srs.srs_number_streams = 1;
@@ -268,27 +291,38 @@ void SctpTransport::reset(unsigned int stream) {
 }
 
 void SctpTransport::incoming(message_ptr message) {
+	// There could be a race condition here where we receive the remote INIT before the thread in
+	// usrsctp_connect sends the local one, which would result in the connection being aborted.
+	// Therefore, we need to wait for data to be sent on our side (i.e. the local INIT) before
+	// proceeding.
+	if (!mConnectDataSent) {
+		std::unique_lock<std::mutex> lock(mConnectMutex);
+		mConnectCondition.wait(lock, [this] { return mConnectDataSent == true; });
+	}
+
 	usrsctp_conninput(this, message->data(), message->size(), 0);
 }
 
 void SctpTransport::runConnect() {
 	struct sockaddr_conn sconn = {};
 	sconn.sconn_family = AF_CONN;
-	sconn.sconn_port = htons(mRemotePort);
+	sconn.sconn_port = htons(mPort);
 	sconn.sconn_addr = this;
 #ifdef HAVE_SCONN_LEN
 	sconn.sconn_len = sizeof(sconn);
 #endif
 
-	// Blocks until connection succeeds/fails
+	// According to the IETF draft, both endpoints must initiate the SCTP association, in a
+	// simultaneous-open manner, irrelevent of the SDP setup role.
+	// See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3
 	if (usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)) != 0) {
 		std::cerr << "SCTP connection failed, errno=" << errno << std::endl;
 		mStopping = true;
 		return;
 	}
 
-    mIsReady = true;
-    mReadyCallback();
+	mIsReady = true;
+	mReadyCallback();
 }
 
 } // namespace rtc

+ 15 - 9
src/sctptransport.hpp

@@ -36,12 +36,13 @@ class SctpTransport : public Transport {
 public:
 	using ready_callback = std::function<void(void)>;
 
-	SctpTransport(std::shared_ptr<Transport> lower, ready_callback ready, message_callback recv);
+	SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, ready_callback ready,
+	              message_callback recv);
 	~SctpTransport();
 
-        bool isReady() const;
+	bool isReady() const;
 
-        bool send(message_ptr message);
+	bool send(message_ptr message);
 	void reset(unsigned int stream);
 
 private:
@@ -67,22 +68,27 @@ private:
 	ready_callback mReadyCallback;
 
 	struct socket *mSock;
-	uint16_t mLocalPort;
-	uint16_t mRemotePort;
+	uint16_t mPort;
 
-        std::thread mConnectThread;
+	std::thread mConnectThread;
 	std::atomic<bool> mStopping = false;
-        std::atomic<bool> mIsReady = false;
+	std::atomic<bool> mIsReady = false;
 
-        std::mutex mConnectMutex;
+	std::mutex mConnectMutex;
 	std::condition_variable mConnectCondition;
+	std::atomic<bool> mConnectDataSent = false;
 
 	static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
 	static int ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
 	                        struct sctp_rcvinfo recv_info, int flags, void *user_data);
+
+	void GlobalInit();
+	void GlobalCleanup();
+
+	static std::mutex GlobalMutex;
+	static int InstancesCount;
 };
 
 } // namespace rtc
 
 #endif
-