Browse Source

Revised synchronization

Paul-Louis Ageneau 5 years ago
parent
commit
e5a19f85ed

+ 2 - 2
include/rtc/include.hpp

@@ -57,13 +57,13 @@ public:
 	~synchronized_callback() { *this = nullptr; }
 
 	synchronized_callback &operator=(std::function<void(P...)> func) {
-		std::lock_guard<std::recursive_mutex> lock(mutex);
+		std::lock_guard lock(mutex);
 		callback = func;
 		return *this;
 	}
 
 	void operator()(P... args) const {
-		std::lock_guard<std::recursive_mutex> lock(mutex);
+		std::lock_guard lock(mutex);
 		if (callback)
 			callback(args...);
 	}

+ 6 - 5
include/rtc/peerconnection.hpp

@@ -31,6 +31,7 @@
 #include <atomic>
 #include <functional>
 #include <list>
+#include <mutex>
 #include <thread>
 #include <unordered_map>
 
@@ -83,9 +84,9 @@ public:
 	void onGatheringStateChange(std::function<void(GatheringState state)> callback);
 
 private:
-	void initIceTransport(Description::Role role);
-	void initDtlsTransport();
-	void initSctpTransport();
+	std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
+	std::shared_ptr<DtlsTransport> initDtlsTransport();
+	std::shared_ptr<SctpTransport> initSctpTransport();
 
 	bool checkFingerprint(const std::string &fingerprint) const;
 	void forwardMessage(message_ptr message);
@@ -103,8 +104,8 @@ private:
 	const Configuration mConfig;
 	const std::shared_ptr<Certificate> mCertificate;
 
-	std::optional<Description> mLocalDescription;
-	std::optional<Description> mRemoteDescription;
+	std::optional<Description> mLocalDescription, mRemoteDescription;
+	mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;
 
 	std::shared_ptr<IceTransport> mIceTransport;
 	std::shared_ptr<DtlsTransport> mDtlsTransport;

+ 9 - 9
include/rtc/queue.hpp

@@ -67,31 +67,31 @@ Queue<T>::Queue(size_t limit, amount_function func) : mLimit(limit), mAmount(0)
 template <typename T> Queue<T>::~Queue() { stop(); }
 
 template <typename T> void Queue<T>::stop() {
-	std::lock_guard<std::mutex> lock(mMutex);
+	std::lock_guard lock(mMutex);
 	mStopping = true;
 	mPopCondition.notify_all();
 	mPushCondition.notify_all();
 }
 
 template <typename T> bool Queue<T>::empty() const {
-	std::lock_guard<std::mutex> lock(mMutex);
+	std::lock_guard lock(mMutex);
 	return mQueue.empty();
 }
 
 template <typename T> size_t Queue<T>::size() const {
-	std::lock_guard<std::mutex> lock(mMutex);
+	std::lock_guard lock(mMutex);
 	return mQueue.size();
 }
 
 template <typename T> size_t Queue<T>::amount() const {
-	std::lock_guard<std::mutex> lock(mMutex);
+	std::lock_guard lock(mMutex);
 	return mAmount;
 }
 
 template <typename T> void Queue<T>::push(const T &element) { push(T{element}); }
 
 template <typename T> void Queue<T>::push(T &&element) {
-	std::unique_lock<std::mutex> lock(mMutex);
+	std::unique_lock lock(mMutex);
 	mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
 	if (!mStopping) {
 		mAmount += mAmountFunction(element);
@@ -101,7 +101,7 @@ template <typename T> void Queue<T>::push(T &&element) {
 }
 
 template <typename T> std::optional<T> Queue<T>::pop() {
-	std::unique_lock<std::mutex> lock(mMutex);
+	std::unique_lock lock(mMutex);
 	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
 	if (!mQueue.empty()) {
 		mAmount -= mAmountFunction(mQueue.front());
@@ -114,7 +114,7 @@ template <typename T> std::optional<T> Queue<T>::pop() {
 }
 
 template <typename T> std::optional<T> Queue<T>::peek() {
-	std::unique_lock<std::mutex> lock(mMutex);
+	std::unique_lock lock(mMutex);
 	if (!mQueue.empty()) {
 		return std::optional<T>{mQueue.front()};
 	} else {
@@ -123,12 +123,12 @@ template <typename T> std::optional<T> Queue<T>::peek() {
 }
 
 template <typename T> void Queue<T>::wait() {
-	std::unique_lock<std::mutex> lock(mMutex);
+	std::unique_lock lock(mMutex);
 	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
 }
 
 template <typename T> void Queue<T>::wait(const std::chrono::milliseconds &duration) {
-	std::unique_lock<std::mutex> lock(mMutex);
+	std::unique_lock lock(mMutex);
 	mPopCondition.wait_for(lock, duration, [this]() { return !mQueue.empty() || mStopping; });
 }
 

+ 2 - 2
src/certificate.cpp

@@ -145,7 +145,7 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
 	static std::unordered_map<string, shared_ptr<Certificate>> cache;
 	static std::mutex cacheMutex;
 
-	std::lock_guard<std::mutex> lock(cacheMutex);
+	std::lock_guard lock(cacheMutex);
 	if (auto it = cache.find(commonName); it != cache.end())
 		return it->second;
 
@@ -241,7 +241,7 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
 	static std::unordered_map<string, shared_ptr<Certificate>> cache;
 	static std::mutex cacheMutex;
 
-	std::lock_guard<std::mutex> lock(cacheMutex);
+	std::lock_guard lock(cacheMutex);
 	if (auto it = cache.find(commonName); it != cache.end())
 		return it->second;
 

+ 13 - 5
src/dtlstransport.cpp

@@ -85,6 +85,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 }
 
 DtlsTransport::~DtlsTransport() {
+	stop();
+
 	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
 	gnutls_deinit(mSession);
 }
@@ -94,8 +96,10 @@ DtlsTransport::State DtlsTransport::state() const { return mState; }
 void DtlsTransport::stop() {
 	Transport::stop();
 
-	mIncomingQueue.stop();
-	mRecvThread.join();
+	if (mRecvThread.joinable()) {
+		mIncomingQueue.stop();
+		mRecvThread.join();
+	}
 }
 
 bool DtlsTransport::send(message_ptr message) {
@@ -293,7 +297,7 @@ int DtlsTransport::TransportExIndex = -1;
 std::mutex DtlsTransport::GlobalMutex;
 
 void DtlsTransport::GlobalInit() {
-	std::lock_guard<std::mutex> lock(GlobalMutex);
+	std::lock_guard lock(GlobalMutex);
 	if (TransportExIndex < 0) {
 		TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
 	}
@@ -358,6 +362,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 }
 
 DtlsTransport::~DtlsTransport() {
+	stop();
+
 	SSL_shutdown(mSsl);
 	SSL_free(mSsl);
 	SSL_CTX_free(mCtx);
@@ -366,8 +372,10 @@ DtlsTransport::~DtlsTransport() {
 void DtlsTransport::stop() {
 	Transport::stop();
 
-	mIncomingQueue.stop();
-	mRecvThread.join();
+	if (mRecvThread.joinable()) {
+		mIncomingQueue.stop();
+		mRecvThread.join();
+	}
 }
 
 DtlsTransport::State DtlsTransport::state() const { return mState; }

+ 2 - 2
src/dtlstransport.hpp

@@ -55,10 +55,10 @@ public:
 	State state() const;
 
 	void stop() override;
-	bool send(message_ptr message); // false if dropped
+	bool send(message_ptr message) override; // false if dropped
 
 private:
-	void incoming(message_ptr message);
+	void incoming(message_ptr message) override;
 	void changeState(State state);
 	void runRecvLoop();
 

+ 5 - 3
src/icetransport.cpp

@@ -130,11 +130,13 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
 	                       RecvCallback, this);
 }
 
-IceTransport::~IceTransport() {}
+IceTransport::~IceTransport() { stop(); }
 
 void IceTransport::stop() {
-	g_main_loop_quit(mMainLoop.get());
-	mMainLoopThread.join();
+	if (mMainLoopThread.joinable()) {
+		g_main_loop_quit(mMainLoop.get());
+		mMainLoopThread.join();
+	}
 }
 
 Description::Role IceTransport::role() const { return mRole; }

+ 2 - 2
src/icetransport.hpp

@@ -71,9 +71,9 @@ public:
 	bool send(message_ptr message) override; // false if dropped
 
 private:
-	void incoming(message_ptr message);
+	void incoming(message_ptr message) override;
 	void incoming(const byte *data, int size);
-	void outgoing(message_ptr message);
+	void outgoing(message_ptr message) override;
 
 	void changeState(State state);
 	void changeGatheringState(GatheringState state);

+ 79 - 44
src/peerconnection.cpp

@@ -20,6 +20,7 @@
 #include "certificate.hpp"
 #include "dtlstransport.hpp"
 #include "icetransport.hpp"
+#include "include.hpp"
 #include "sctptransport.hpp"
 
 #include <iostream>
@@ -37,12 +38,12 @@ PeerConnection::PeerConnection(const Configuration &config)
     : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
 
 PeerConnection::~PeerConnection() {
-	if (mIceTransport)
-		mIceTransport->stop();
-	if (mDtlsTransport)
-		mDtlsTransport->stop();
-	if (mSctpTransport)
-		mSctpTransport->stop();
+	if (auto transport = std::atomic_load(&mIceTransport))
+		transport->stop();
+	if (auto transport = std::atomic_load(&mDtlsTransport))
+		transport->stop();
+	if (auto transport = std::atomic_load(&mSctpTransport))
+		transport->stop();
 
 	mSctpTransport.reset();
 	mDtlsTransport.reset();
@@ -55,26 +56,36 @@ PeerConnection::State PeerConnection::state() const { return mState; }
 
 PeerConnection::GatheringState PeerConnection::gatheringState() const { return mGatheringState; }
 
-std::optional<Description> PeerConnection::localDescription() const { return mLocalDescription; }
+std::optional<Description> PeerConnection::localDescription() const {
+	std::lock_guard lock(mLocalDescriptionMutex);
+	return mLocalDescription;
+}
 
-std::optional<Description> PeerConnection::remoteDescription() const { return mRemoteDescription; }
+std::optional<Description> PeerConnection::remoteDescription() const {
+	std::lock_guard lock(mRemoteDescriptionMutex);
+	return mRemoteDescription;
+}
 
 void PeerConnection::setRemoteDescription(Description description) {
+	std::lock_guard lock(mRemoteDescriptionMutex);
+
 	auto remoteCandidates = description.extractCandidates();
 	mRemoteDescription.emplace(std::move(description));
 
-	if (!mIceTransport)
-		initIceTransport(Description::Role::ActPass);
+	auto iceTransport = std::atomic_load(&mIceTransport);
+	if (!iceTransport)
+		iceTransport = initIceTransport(Description::Role::ActPass);
 
-	mIceTransport->setRemoteDescription(*mRemoteDescription);
+	iceTransport->setRemoteDescription(*mRemoteDescription);
 
 	if (mRemoteDescription->type() == Description::Type::Offer) {
 		// This is an offer and we are the answerer.
-		processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Answer));
-		mIceTransport->gatherLocalCandidates();
+		processLocalDescription(iceTransport->getLocalDescription(Description::Type::Answer));
+		iceTransport->gatherLocalCandidates();
 	} else {
 		// This is an answer and we are the offerer.
-		if (!mSctpTransport && mIceTransport->role() == Description::Role::Active) {
+		auto sctpTransport = std::atomic_load(&mSctpTransport);
+		if (!sctpTransport && iceTransport->role() == Description::Role::Active) {
 			// Since we assumed passive role during DataChannel creation, we need to shift the
 			// stream numbers by one to shift them from odd to even.
 			decltype(mDataChannels) newDataChannels;
@@ -92,16 +103,19 @@ void PeerConnection::setRemoteDescription(Description description) {
 }
 
 void PeerConnection::addRemoteCandidate(Candidate candidate) {
-	if (!mRemoteDescription || !mIceTransport)
+	std::lock_guard lock(mRemoteDescriptionMutex);
+
+	auto iceTransport = std::atomic_load(&mIceTransport);
+	if (!mRemoteDescription || !iceTransport)
 		throw std::logic_error("Remote candidate set without remote description");
 
 	mRemoteDescription->addCandidate(candidate);
 
 	if (candidate.resolve(Candidate::ResolveMode::Simple)) {
-		mIceTransport->addRemoteCandidate(candidate);
+		iceTransport->addRemoteCandidate(candidate);
 	} else {
 		// OK, we might need a lookup, do it asynchronously
-		weak_ptr<IceTransport> weakIceTransport{mIceTransport};
+		weak_ptr<IceTransport> weakIceTransport{iceTransport};
 		std::thread t([weakIceTransport, candidate]() mutable {
 			if (candidate.resolve(Candidate::ResolveMode::Lookup))
 				if (auto iceTransport = weakIceTransport.lock())
@@ -112,11 +126,13 @@ void PeerConnection::addRemoteCandidate(Candidate candidate) {
 }
 
 std::optional<string> PeerConnection::localAddress() const {
-	return mIceTransport ? mIceTransport->getLocalAddress() : nullopt;
+	auto iceTransport = std::atomic_load(&mIceTransport);
+	return iceTransport ? iceTransport->getLocalAddress() : nullopt;
 }
 
 std::optional<string> PeerConnection::remoteAddress() const {
-	return mIceTransport ? mIceTransport->getRemoteAddress() : nullopt;
+	auto iceTransport = std::atomic_load(&mIceTransport);
+	return iceTransport ? iceTransport->getRemoteAddress() : nullopt;
 }
 
 shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
@@ -126,7 +142,8 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
 	// setup:passive. [...] Thus, setup:active is RECOMMENDED.
 	// See https://tools.ietf.org/html/rfc5763#section-5
 	// Therefore, we assume passive role when we are the offerer.
-	auto role = mIceTransport ? mIceTransport->role() : Description::Role::Passive;
+	auto iceTransport = std::atomic_load(&mIceTransport);
+	auto role = iceTransport ? iceTransport->role() : Description::Role::Passive;
 
 	// The active side must use streams with even identifiers, whereas the passive side must use
 	// streams with odd identifiers.
@@ -142,15 +159,17 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
 	    std::make_shared<DataChannel>(shared_from_this(), stream, label, protocol, reliability);
 	mDataChannels.insert(std::make_pair(stream, channel));
 
-	if (!mIceTransport) {
+	if (!iceTransport) {
 		// RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of
 		// setup:actpass.
 		// See https://tools.ietf.org/html/rfc5763#section-5
-		initIceTransport(Description::Role::ActPass);
-		processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Offer));
-		mIceTransport->gatherLocalCandidates();
-	} else if (mSctpTransport && mSctpTransport->state() == SctpTransport::State::Connected) {
-		channel->open(mSctpTransport);
+		iceTransport = initIceTransport(Description::Role::ActPass);
+		processLocalDescription(iceTransport->getLocalDescription(Description::Type::Offer));
+		iceTransport->gatherLocalCandidates();
+	} else {
+		if (auto transport = std::atomic_load(&mSctpTransport))
+			if (transport->state() == SctpTransport::State::Connected)
+				channel->open(transport);
 	}
 	return channel;
 }
@@ -177,8 +196,8 @@ void PeerConnection::onGatheringStateChange(std::function<void(GatheringState st
 	mGatheringStateChangeCallback = callback;
 }
 
-void PeerConnection::initIceTransport(Description::Role role) {
-	mIceTransport = std::make_shared<IceTransport>(
+shared_ptr<IceTransport> PeerConnection::initIceTransport(Description::Role role) {
+	auto transport = std::make_shared<IceTransport>(
 	    mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, _1),
 	    [this](IceTransport::State state) {
 		    switch (state) {
@@ -211,11 +230,14 @@ void PeerConnection::initIceTransport(Description::Role role) {
 			    break;
 		    }
 	    });
+	std::atomic_store(&mIceTransport, transport);
+	return transport;
 }
 
-void PeerConnection::initDtlsTransport() {
-	mDtlsTransport = std::make_shared<DtlsTransport>(
-	    mIceTransport, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, _1),
+shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
+	auto lower = std::atomic_load(&mIceTransport);
+	auto transport = std::make_shared<DtlsTransport>(
+	    lower, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, _1),
 	    [this](DtlsTransport::State state) {
 		    switch (state) {
 		    case DtlsTransport::State::Connected:
@@ -229,12 +251,15 @@ void PeerConnection::initDtlsTransport() {
 			    break;
 		    }
 	    });
+	std::atomic_store(&mDtlsTransport, transport);
+	return transport;
 }
 
-void PeerConnection::initSctpTransport() {
-	uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT);
-	mSctpTransport = std::make_shared<SctpTransport>(
-	    mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1),
+shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
+	uint16_t sctpPort = remoteDescription()->sctpPort().value_or(DEFAULT_SCTP_PORT);
+	auto lower = std::atomic_load(&mDtlsTransport);
+	auto transport = std::make_shared<SctpTransport>(
+	    lower, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1),
 	    std::bind(&PeerConnection::forwardBufferedAmount, this, _1, _2),
 	    [this](SctpTransport::State state) {
 		    switch (state) {
@@ -253,9 +278,12 @@ void PeerConnection::initSctpTransport() {
 			    break;
 		    }
 	    });
+	std::atomic_store(&mSctpTransport, transport);
+	return transport;
 }
 
 bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
+	std::lock_guard lock(mRemoteDescriptionMutex);
 	if (auto expectedFingerprint =
 	        mRemoteDescription ? mRemoteDescription->fingerprint() : nullopt) {
 		return *expectedFingerprint == fingerprint;
@@ -264,9 +292,6 @@ bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
 }
 
 void PeerConnection::forwardMessage(message_ptr message) {
-	if (!mIceTransport || !mSctpTransport)
-		throw std::logic_error("Got a DataChannel message without transport");
-
 	if (!message) {
 		closeDataChannels();
 		return;
@@ -281,19 +306,24 @@ void PeerConnection::forwardMessage(message_ptr message) {
 		}
 	}
 
+	auto iceTransport = std::atomic_load(&mIceTransport);
+	auto sctpTransport = std::atomic_load(&mSctpTransport);
+	if (!iceTransport || !sctpTransport)
+		return;
+
 	if (!channel) {
 		const byte dataChannelOpenMessage{0x03};
-		unsigned int remoteParity = (mIceTransport->role() == Description::Role::Active) ? 1 : 0;
+		unsigned int remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0;
 		if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
 		    message->stream % 2 == remoteParity) {
 			channel =
-			    std::make_shared<DataChannel>(shared_from_this(), mSctpTransport, message->stream);
+			    std::make_shared<DataChannel>(shared_from_this(), sctpTransport, message->stream);
 			channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this,
 			                          weak_ptr<DataChannel>{channel}));
 			mDataChannels.insert(std::make_pair(message->stream, channel));
 		} else {
 			// Invalid, close the DataChannel by resetting the stream
-			mSctpTransport->reset(message->stream);
+			sctpTransport->reset(message->stream);
 			return;
 		}
 	}
@@ -330,16 +360,20 @@ void PeerConnection::iterateDataChannels(
 }
 
 void PeerConnection::openDataChannels() {
-	iterateDataChannels([this](shared_ptr<DataChannel> channel) { channel->open(mSctpTransport); });
+	if (auto transport = std::atomic_load(&mSctpTransport))
+		iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->open(transport); });
 }
 
 void PeerConnection::closeDataChannels() {
-	iterateDataChannels([](shared_ptr<DataChannel> channel) { channel->close(); });
+	iterateDataChannels([&](shared_ptr<DataChannel> channel) { channel->close(); });
 }
 
 void PeerConnection::processLocalDescription(Description description) {
-	auto remoteSctpPort = mRemoteDescription ? mRemoteDescription->sctpPort() : nullopt;
+	std::optional<uint16_t> remoteSctpPort;
+	if (auto remote = remoteDescription())
+		remoteSctpPort = remote->sctpPort();
 
+	std::lock_guard lock(mLocalDescriptionMutex);
 	mLocalDescription.emplace(std::move(description));
 	mLocalDescription->setFingerprint(mCertificate->fingerprint());
 	mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
@@ -349,6 +383,7 @@ void PeerConnection::processLocalDescription(Description description) {
 }
 
 void PeerConnection::processLocalCandidate(Candidate candidate) {
+	std::lock_guard lock(mLocalDescriptionMutex);
 	if (!mLocalDescription)
 		throw std::logic_error("Got a local candidate without local description");
 

+ 16 - 18
src/sctptransport.cpp

@@ -33,7 +33,7 @@ std::mutex SctpTransport::GlobalMutex;
 int SctpTransport::InstancesCount = 0;
 
 void SctpTransport::GlobalInit() {
-	std::unique_lock<std::mutex> lock(GlobalMutex);
+	std::lock_guard lock(GlobalMutex);
 	if (InstancesCount++ == 0) {
 		usrsctp_init(0, &SctpTransport::WriteCallback, nullptr);
 		usrsctp_sysctl_set_sctp_ecn_enable(0);
@@ -41,7 +41,7 @@ void SctpTransport::GlobalInit() {
 }
 
 void SctpTransport::GlobalCleanup() {
-	std::unique_lock<std::mutex> lock(GlobalMutex);
+	std::lock_guard lock(GlobalMutex);
 	if (--InstancesCount == 0) {
 		usrsctp_finish();
 	}
@@ -143,6 +143,8 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 }
 
 SctpTransport::~SctpTransport() {
+	stop();
+
 	if (mSock) {
 		usrsctp_shutdown(mSock, SHUT_RDWR);
 		usrsctp_close(mSock);
@@ -156,15 +158,14 @@ SctpTransport::State SctpTransport::state() const { return mState; }
 
 void SctpTransport::stop() {
 	Transport::stop();
+	onRecv(nullptr);
 
 	mSendQueue.stop();
 
 	// Unblock incoming
-	if (!mConnectDataSent) {
-		std::unique_lock<std::mutex> lock(mConnectMutex);
-		mConnectDataSent = true;
-		mConnectCondition.notify_all();
-	}
+	std::unique_lock<std::mutex> lock(mConnectMutex);
+	mConnectDataSent = true;
+	mConnectCondition.notify_all();
 }
 
 void SctpTransport::connect() {
@@ -190,7 +191,7 @@ void SctpTransport::connect() {
 }
 
 bool SctpTransport::send(message_ptr message) {
-	std::lock_guard<std::mutex> lock(mSendMutex);
+	std::lock_guard lock(mSendMutex);
 
 	if (!message)
 		return mSendQueue.empty();
@@ -225,8 +226,8 @@ void SctpTransport::incoming(message_ptr message) {
 	// There could be a race condition here where we receive the remote INIT before the local one is
 	// sent, which would result in the connection being aborted. Therefore, we need to wait for data
 	// to be sent on our side (i.e. the local INIT) before proceeding.
-	if (!mConnectDataSent) {
-		std::unique_lock<std::mutex> lock(mConnectMutex);
+	{
+		std::unique_lock lock(mConnectMutex);
 		mConnectCondition.wait(lock, [this]() -> bool { return mConnectDataSent; });
 	}
 
@@ -361,7 +362,7 @@ int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, co
 
 int SctpTransport::handleSend(size_t free) {
 	try {
-		std::lock_guard<std::mutex> lock(mSendMutex);
+		std::lock_guard lock(mSendMutex);
 		trySendQueue();
 	} catch (const std::exception &e) {
 		std::cerr << "SCTP send: " << e.what() << std::endl;
@@ -374,11 +375,9 @@ int SctpTransport::handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_
 	try {
 		outgoing(make_message(data, data + len));
 
-		if (!mConnectDataSent) {
-			std::unique_lock<std::mutex> lock(mConnectMutex);
-			mConnectDataSent = true;
-			mConnectCondition.notify_all();
-		}
+		std::unique_lock lock(mConnectMutex);
+		mConnectDataSent = true;
+		mConnectCondition.notify_all();
 	} catch (const std::exception &e) {
 		std::cerr << "SCTP write: " << e.what() << std::endl;
 		return -1;
@@ -453,7 +452,6 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 	switch (notify->sn_header.sn_type) {
 	case SCTP_ASSOC_CHANGE: {
 		const struct sctp_assoc_change &assoc_change = notify->sn_assoc_change;
-		std::unique_lock<std::mutex> lock(mConnectMutex);
 		if (assoc_change.sac_state == SCTP_COMM_UP) {
 			changeState(State::Connected);
 		} else {
@@ -468,7 +466,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 	case SCTP_SENDER_DRY_EVENT: {
 		// It not should be necessary since the send callback should have been called already,
 		// but to be sure, let's try to send now.
-		std::lock_guard<std::mutex> lock(mSendMutex);
+		std::lock_guard lock(mSendMutex);
 		trySendQueue();
 	}
 	case SCTP_STREAM_RESET_EVENT: {

+ 2 - 3
src/sctptransport.hpp

@@ -68,7 +68,7 @@ private:
 	};
 
 	void connect();
-	void incoming(message_ptr message);
+	void incoming(message_ptr message) override;
 	void changeState(State state);
 
 	bool trySendQueue();
@@ -93,8 +93,7 @@ private:
 
 	std::mutex mConnectMutex;
 	std::condition_variable mConnectCondition;
-	std::atomic<bool> mConnectDataSent = false;
-	std::atomic<bool> mStopping = false;
+	bool mConnectDataSent = false;
 
 	state_callback mStateChangeCallback;
 	std::atomic<State> mState;