Browse Source

Fixed transport synchronization on destruction

Paul-Louis Ageneau 5 years ago
parent
commit
de5aff68e6

+ 1 - 1
include/rtc/datachannel.hpp

@@ -71,7 +71,7 @@ private:
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);
 	void processOpenMessage(message_ptr message);
 	void processOpenMessage(message_ptr message);
 
 
-	const std::shared_ptr<PeerConnection> mPeerConnection;
+	std::shared_ptr<PeerConnection> mPeerConnection;
 	std::shared_ptr<SctpTransport> mSctpTransport;
 	std::shared_ptr<SctpTransport> mSctpTransport;
 
 
 	unsigned int mStream;
 	unsigned int mStream;

+ 5 - 6
include/rtc/peerconnection.hpp

@@ -87,17 +87,16 @@ private:
 	void initDtlsTransport();
 	void initDtlsTransport();
 	void initSctpTransport();
 	void initSctpTransport();
 
 
-	bool checkFingerprint(std::weak_ptr<PeerConnection> weak_this, const std::string &fingerprint) const;
-	void forwardMessage(std::weak_ptr<PeerConnection> weak_this, message_ptr message);
-	void forwardBufferedAmount(std::weak_ptr<PeerConnection> weak_this, uint16_t stream,
-	                           size_t amount);
+	bool checkFingerprint(const std::string &fingerprint) const;
+	void forwardMessage(message_ptr message);
+	void forwardBufferedAmount(uint16_t stream, size_t amount);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void openDataChannels();
 	void openDataChannels();
 	void closeDataChannels();
 	void closeDataChannels();
 
 
 	void processLocalDescription(Description description);
 	void processLocalDescription(Description description);
-	void processLocalCandidate(std::weak_ptr<PeerConnection> weak_this, Candidate candidate);
-	void triggerDataChannel(std::weak_ptr<PeerConnection> weak_this, std::weak_ptr<DataChannel> weakDataChannel);
+	void processLocalCandidate(Candidate candidate);
+	void triggerDataChannel(std::weak_ptr<DataChannel> weakDataChannel);
 	void changeState(State state);
 	void changeState(State state);
 	void changeGatheringState(GatheringState state);
 	void changeGatheringState(GatheringState state);
 
 

+ 4 - 0
src/datachannel.cpp

@@ -81,6 +81,10 @@ void DataChannel::close() {
 		if (mSctpTransport)
 		if (mSctpTransport)
 			mSctpTransport->reset(mStream);
 			mSctpTransport->reset(mStream);
 	}
 	}
+
+	// Reset mSctpTransport first so SctpTransport is never alive without PeerConnection
+	mSctpTransport.reset();
+	mPeerConnection.reset();
 }
 }
 
 
 bool DataChannel::send(const std::variant<binary, string> &data) {
 bool DataChannel::send(const std::variant<binary, string> &data) {

+ 6 - 6
src/dtlstransport.cpp

@@ -85,10 +85,10 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 }
 }
 
 
 DtlsTransport::~DtlsTransport() {
 DtlsTransport::~DtlsTransport() {
-	mIncomingQueue.stop();
+	resetLower();
 
 
-	if (mRecvThread.joinable())
-		mRecvThread.join();
+	mIncomingQueue.stop();
+	mRecvThread.join();
 
 
 	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
 	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
 	gnutls_deinit(mSession);
 	gnutls_deinit(mSession);
@@ -356,10 +356,10 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 }
 }
 
 
 DtlsTransport::~DtlsTransport() {
 DtlsTransport::~DtlsTransport() {
-	mIncomingQueue.stop();
+	resetLower();
 
 
-	if (mRecvThread.joinable())
-		mRecvThread.join();
+	mIncomingQueue.stop();
+	mRecvThread.join();
 
 
 	SSL_shutdown(mSsl);
 	SSL_shutdown(mSsl);
 	SSL_free(mSsl);
 	SSL_free(mSsl);

+ 1 - 2
src/icetransport.cpp

@@ -132,8 +132,7 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
 
 
 IceTransport::~IceTransport() {
 IceTransport::~IceTransport() {
 	g_main_loop_quit(mMainLoop.get());
 	g_main_loop_quit(mMainLoop.get());
-	if (mMainLoopThread.joinable())
-		mMainLoopThread.join();
+	mMainLoopThread.join();
 }
 }
 
 
 Description::Role IceTransport::role() const { return mRole; }
 Description::Role IceTransport::role() const { return mRole; }

+ 20 - 51
src/peerconnection.cpp

@@ -36,7 +36,11 @@ PeerConnection::PeerConnection() : PeerConnection(Configuration()) {}
 PeerConnection::PeerConnection(const Configuration &config)
 PeerConnection::PeerConnection(const Configuration &config)
     : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
     : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
 
 
-PeerConnection::~PeerConnection() {}
+PeerConnection::~PeerConnection() {
+	mSctpTransport.reset();
+	mDtlsTransport.reset();
+	mIceTransport.reset();
+}
 
 
 const Configuration *PeerConnection::config() const { return &mConfig; }
 const Configuration *PeerConnection::config() const { return &mConfig; }
 
 
@@ -168,11 +172,8 @@ void PeerConnection::onGatheringStateChange(std::function<void(GatheringState st
 
 
 void PeerConnection::initIceTransport(Description::Role role) {
 void PeerConnection::initIceTransport(Description::Role role) {
 	mIceTransport = std::make_shared<IceTransport>(
 	mIceTransport = std::make_shared<IceTransport>(
-	    mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, weak_ptr<PeerConnection>{shared_from_this()}, _1),
-	    [this, weak_this = weak_ptr<PeerConnection>{shared_from_this()}](IceTransport::State state) {
-        auto strong_this = weak_this.lock();
-        if (!strong_this) return;
-
+	    mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, _1),
+	    [this](IceTransport::State state) {
 		    switch (state) {
 		    switch (state) {
 		    case IceTransport::State::Connecting:
 		    case IceTransport::State::Connecting:
 			    changeState(State::Connecting);
 			    changeState(State::Connecting);
@@ -188,10 +189,7 @@ void PeerConnection::initIceTransport(Description::Role role) {
 			    break;
 			    break;
 		    }
 		    }
 	    },
 	    },
-	    [this, weak_this = weak_ptr<PeerConnection>{shared_from_this()}](IceTransport::GatheringState state) {
-        auto strong_this = weak_this.lock();
-        if (!strong_this) return;
-
+	    [this](IceTransport::GatheringState state) {
 		    switch (state) {
 		    switch (state) {
 		    case IceTransport::GatheringState::InProgress:
 		    case IceTransport::GatheringState::InProgress:
 			    changeGatheringState(GatheringState::InProgress);
 			    changeGatheringState(GatheringState::InProgress);
@@ -210,11 +208,8 @@ void PeerConnection::initIceTransport(Description::Role role) {
 
 
 void PeerConnection::initDtlsTransport() {
 void PeerConnection::initDtlsTransport() {
 	mDtlsTransport = std::make_shared<DtlsTransport>(
 	mDtlsTransport = std::make_shared<DtlsTransport>(
-	    mIceTransport, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, weak_ptr<PeerConnection>{shared_from_this()}, _1),
-	    [this, weak_this = weak_ptr<PeerConnection>{shared_from_this()}](DtlsTransport::State state) {
-        auto strong_this = weak_this.lock();
-        if (!strong_this) return;
-
+	    mIceTransport, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, _1),
+	    [this](DtlsTransport::State state) {
 		    switch (state) {
 		    switch (state) {
 		    case DtlsTransport::State::Connected:
 		    case DtlsTransport::State::Connected:
 			    initSctpTransport();
 			    initSctpTransport();
@@ -232,17 +227,9 @@ void PeerConnection::initDtlsTransport() {
 void PeerConnection::initSctpTransport() {
 void PeerConnection::initSctpTransport() {
 	uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT);
 	uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT);
 	mSctpTransport = std::make_shared<SctpTransport>(
 	mSctpTransport = std::make_shared<SctpTransport>(
-	    mDtlsTransport, sctpPort,
-	    std::bind(&PeerConnection::forwardMessage, this,
-	              weak_ptr<PeerConnection>{shared_from_this()}, _1),
-	    std::bind(&PeerConnection::forwardBufferedAmount, this,
-	              weak_ptr<PeerConnection>{shared_from_this()}, _1, _2),
-	    [this,
-	     weak_this = weak_ptr<PeerConnection>{shared_from_this()}](SctpTransport::State state) {
-		    auto strong_this = weak_this.lock();
-		    if (!strong_this)
-			    return;
-
+	    mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1),
+	    std::bind(&PeerConnection::forwardBufferedAmount, this, _1, _2),
+	    [this](SctpTransport::State state) {
 		    switch (state) {
 		    switch (state) {
 		    case SctpTransport::State::Connected:
 		    case SctpTransport::State::Connected:
 			    changeState(State::Connected);
 			    changeState(State::Connected);
@@ -261,10 +248,7 @@ void PeerConnection::initSctpTransport() {
 	    });
 	    });
 }
 }
 
 
-bool PeerConnection::checkFingerprint(weak_ptr<PeerConnection> weak_this, const std::string &fingerprint) const {
-  auto strong_this = weak_this.lock();
-  if (!strong_this) return false;
-
+bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
 	if (auto expectedFingerprint =
 	if (auto expectedFingerprint =
 	        mRemoteDescription ? mRemoteDescription->fingerprint() : nullopt) {
 	        mRemoteDescription ? mRemoteDescription->fingerprint() : nullopt) {
 		return *expectedFingerprint == fingerprint;
 		return *expectedFingerprint == fingerprint;
@@ -272,10 +256,7 @@ bool PeerConnection::checkFingerprint(weak_ptr<PeerConnection> weak_this, const
 	return false;
 	return false;
 }
 }
 
 
-void PeerConnection::forwardMessage(weak_ptr<PeerConnection> weak_this, message_ptr message) {
-  auto strong_this = weak_this.lock();
-  if (!strong_this) return;
-
+void PeerConnection::forwardMessage(message_ptr message) {
 	if (!mIceTransport || !mSctpTransport)
 	if (!mIceTransport || !mSctpTransport)
 		throw std::logic_error("Got a DataChannel message without transport");
 		throw std::logic_error("Got a DataChannel message without transport");
 
 
@@ -300,7 +281,8 @@ void PeerConnection::forwardMessage(weak_ptr<PeerConnection> weak_this, message_
 		    message->stream % 2 == remoteParity) {
 		    message->stream % 2 == remoteParity) {
 			channel =
 			channel =
 			    std::make_shared<DataChannel>(shared_from_this(), mSctpTransport, message->stream);
 			    std::make_shared<DataChannel>(shared_from_this(), mSctpTransport, message->stream);
-			channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this, weak_this, weak_ptr<DataChannel>{channel}));
+			channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this,
+			                          weak_ptr<DataChannel>{channel}));
 			mDataChannels.insert(std::make_pair(message->stream, channel));
 			mDataChannels.insert(std::make_pair(message->stream, channel));
 		} else {
 		} else {
 			// Invalid, close the DataChannel by resetting the stream
 			// Invalid, close the DataChannel by resetting the stream
@@ -312,12 +294,7 @@ void PeerConnection::forwardMessage(weak_ptr<PeerConnection> weak_this, message_
 	channel->incoming(message);
 	channel->incoming(message);
 }
 }
 
 
-void PeerConnection::forwardBufferedAmount(weak_ptr<PeerConnection> weak_this, uint16_t stream,
-                                           size_t amount) {
-	auto strong_this = weak_this.lock();
-	if (!strong_this)
-		return;
-
+void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 	shared_ptr<DataChannel> channel;
 	shared_ptr<DataChannel> channel;
 	if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) {
 	if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) {
 		channel = it->second.lock();
 		channel = it->second.lock();
@@ -364,11 +341,7 @@ void PeerConnection::processLocalDescription(Description description) {
 	mLocalDescriptionCallback(*mLocalDescription);
 	mLocalDescriptionCallback(*mLocalDescription);
 }
 }
 
 
-void PeerConnection::processLocalCandidate(weak_ptr<PeerConnection> weak_this, Candidate candidate) {
-	auto strong_this = weak_this.lock();
-	if (!strong_this)
-		return;
-
+void PeerConnection::processLocalCandidate(Candidate candidate) {
 	if (!mLocalDescription)
 	if (!mLocalDescription)
 		throw std::logic_error("Got a local candidate without local description");
 		throw std::logic_error("Got a local candidate without local description");
 
 
@@ -377,11 +350,7 @@ void PeerConnection::processLocalCandidate(weak_ptr<PeerConnection> weak_this, C
 	mLocalCandidateCallback(candidate);
 	mLocalCandidateCallback(candidate);
 }
 }
 
 
-void PeerConnection::triggerDataChannel(weak_ptr<PeerConnection> weak_this, weak_ptr<DataChannel> weakDataChannel) {
-	auto strong_this = weak_this.lock();
-	if (!strong_this)
-		return;
-
+void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {
 	auto dataChannel = weakDataChannel.lock();
 	auto dataChannel = weakDataChannel.lock();
 	if (!dataChannel)
 	if (!dataChannel)
 		return;
 		return;

+ 1 - 0
src/sctptransport.cpp

@@ -143,6 +143,7 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
 }
 }
 
 
 SctpTransport::~SctpTransport() {
 SctpTransport::~SctpTransport() {
+	resetLower();
 	onRecv(nullptr); // unset recv callback
 	onRecv(nullptr); // unset recv callback
 
 
 	mSendQueue.stop();
 	mSendQueue.stop();

+ 12 - 14
src/transport.hpp

@@ -22,6 +22,7 @@
 #include "include.hpp"
 #include "include.hpp"
 #include "message.hpp"
 #include "message.hpp"
 
 
+#include <atomic>
 #include <functional>
 #include <functional>
 #include <memory>
 #include <memory>
 
 
@@ -32,31 +33,28 @@ using namespace std::placeholders;
 class Transport {
 class Transport {
 public:
 public:
 	Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {
 	Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {
-		if (mLower)
-			mLower->onRecv(std::bind(&Transport::incoming, this, _1));
-	}
-	virtual ~Transport() {
-		if (mLower)
-			mLower->onRecv(nullptr);
+		if (auto lower = std::atomic_load(&mLower))
+			lower->onRecv(std::bind(&Transport::incoming, this, _1));
 	}
 	}
+	virtual ~Transport() { resetLower(); }
 
 
 	virtual bool send(message_ptr message) = 0;
 	virtual bool send(message_ptr message) = 0;
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
 
 
 protected:
 protected:
 	void recv(message_ptr message) { mRecvCallback(message); }
 	void recv(message_ptr message) { mRecvCallback(message); }
+	void resetLower() {
+		if (auto lower = std::atomic_exchange(&mLower, std::shared_ptr<Transport>(nullptr)))
+			lower->onRecv(nullptr);
+	}
 
 
 	virtual void incoming(message_ptr message) = 0;
 	virtual void incoming(message_ptr message) = 0;
-	virtual void outgoing(message_ptr message) { getLower()->send(message); }
-
-private:
-	std::shared_ptr<Transport> getLower() {
-		if (mLower)
-			return mLower;
-		else
-			throw std::logic_error("No lower transport to call");
+	virtual void outgoing(message_ptr message) {
+		if (auto lower = std::atomic_load(&mLower))
+			lower->send(message);
 	}
 	}
 
 
+private:
 	std::shared_ptr<Transport> mLower;
 	std::shared_ptr<Transport> mLower;
 	synchronized_callback<message_ptr> mRecvCallback;
 	synchronized_callback<message_ptr> mRecvCallback;
 };
 };

+ 11 - 0
test/main.cpp

@@ -100,5 +100,16 @@ int main(int argc, char **argv) {
 	});
 	});
 
 
 	this_thread::sleep_for(3s);
 	this_thread::sleep_for(3s);
+
+	if (dc1->isOpen() && dc2->isOpen()) {
+		dc1->close();
+		dc2->close();
+
+		cout << "Success" << endl;
+		return 0;
+	} else {
+		cout << "Failure" << endl;
+		return 1;
+	}
 }
 }