Browse Source

Merge pull request #3 from aaronalbers/aa_lifetime_fixes_

Fixed lifetime issues
Paul-Louis Ageneau 5 years ago
parent
commit
4f6bdc5135
5 changed files with 72 additions and 31 deletions
  1. 5 5
      include/rtc/peerconnection.hpp
  2. 4 5
      src/dtlstransport.cpp
  3. 41 13
      src/peerconnection.cpp
  4. 3 2
      src/sctptransport.cpp
  5. 19 6
      test/main.cpp

+ 5 - 5
include/rtc/peerconnection.hpp

@@ -41,7 +41,7 @@ class IceTransport;
 class DtlsTransport;
 class SctpTransport;
 
-class PeerConnection {
+class PeerConnection : public std::enable_shared_from_this<PeerConnection> {
 public:
 	enum class State : int {
 		New = RTC_NEW,
@@ -85,15 +85,15 @@ private:
 	void initDtlsTransport();
 	void initSctpTransport();
 
-	bool checkFingerprint(const std::string &fingerprint) const;
-	void forwardMessage(message_ptr message);
+	bool checkFingerprint(std::weak_ptr<PeerConnection> weak_this, const std::string &fingerprint) const;
+	void forwardMessage(std::weak_ptr<PeerConnection> weak_this, message_ptr message);
 	void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
 	void openDataChannels();
 	void closeDataChannels();
 
 	void processLocalDescription(Description description);
-	void processLocalCandidate(Candidate candidate);
-	void triggerDataChannel(std::shared_ptr<DataChannel> dataChannel);
+	void processLocalCandidate(std::weak_ptr<PeerConnection> weak_this, Candidate candidate);
+	void triggerDataChannel(std::weak_ptr<PeerConnection> weak_this, std::weak_ptr<DataChannel> weakDataChannel);
 	void changeState(State state);
 	void changeGatheringState(GatheringState state);
 

+ 4 - 5
src/dtlstransport.cpp

@@ -79,12 +79,11 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 }
 
 DtlsTransport::~DtlsTransport() {
+  onRecv(nullptr);
 	mIncomingQueue.stop();
-	if (mRecvThread.joinable())
-		mRecvThread.join();
-
-	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
-	gnutls_deinit(mSession);
+  mRecvThread.join();
+  gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
+  gnutls_deinit(mSession);
 }
 
 DtlsTransport::State DtlsTransport::state() const { return mState; }

+ 41 - 13
src/peerconnection.cpp

@@ -30,6 +30,7 @@ using namespace std::placeholders;
 
 using std::function;
 using std::shared_ptr;
+using std::weak_ptr;
 
 PeerConnection::PeerConnection() : PeerConnection(Configuration()) {}
 
@@ -113,7 +114,7 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
 }
 
 void PeerConnection::onDataChannel(
-    std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback) {
+    std::function<void(shared_ptr<DataChannel> dataChannel)> callback) {
 	mDataChannelCallback = callback;
 }
 
@@ -136,8 +137,11 @@ void PeerConnection::onGatheringStateChange(std::function<void(GatheringState st
 
 void PeerConnection::initIceTransport(Description::Role role) {
 	mIceTransport = std::make_shared<IceTransport>(
-	    mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, _1),
-	    [this](IceTransport::State state) {
+	    mConfig, role, std::bind(&PeerConnection::processLocalCandidate, this, weak_ptr<PeerConnection>{shared_from_this()}, _1),
+	    [this, weak_this = weak_ptr<PeerConnection>{shared_from_this()}](IceTransport::State state) {
+        auto strong_this = weak_this.lock();
+        if (!strong_this) return;
+        
 		    switch (state) {
 		    case IceTransport::State::Connecting:
 			    changeState(State::Connecting);
@@ -153,7 +157,10 @@ void PeerConnection::initIceTransport(Description::Role role) {
 			    break;
 		    }
 	    },
-	    [this](IceTransport::GatheringState state) {
+	    [this, weak_this = weak_ptr<PeerConnection>{shared_from_this()}](IceTransport::GatheringState state) {
+        auto strong_this = weak_this.lock();
+        if (!strong_this) return;
+        
 		    switch (state) {
 		    case IceTransport::GatheringState::InProgress:
 			    changeGatheringState(GatheringState::InProgress);
@@ -172,8 +179,11 @@ void PeerConnection::initIceTransport(Description::Role role) {
 
 void PeerConnection::initDtlsTransport() {
 	mDtlsTransport = std::make_shared<DtlsTransport>(
-	    mIceTransport, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, _1),
-	    [this](DtlsTransport::State state) {
+	    mIceTransport, mCertificate, std::bind(&PeerConnection::checkFingerprint, this, weak_ptr<PeerConnection>{shared_from_this()}, _1),
+	    [this, weak_this = weak_ptr<PeerConnection>{shared_from_this()}](DtlsTransport::State state) {
+        auto strong_this = weak_this.lock();
+        if (!strong_this) return;
+        
 		    switch (state) {
 		    case DtlsTransport::State::Connected:
 			    initSctpTransport();
@@ -191,8 +201,11 @@ void PeerConnection::initDtlsTransport() {
 void PeerConnection::initSctpTransport() {
 	uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT);
 	mSctpTransport = std::make_shared<SctpTransport>(
-	    mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1),
-	    [this](SctpTransport::State state) {
+	    mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, weak_ptr<PeerConnection>{shared_from_this()}, _1),
+	    [this, weak_this = weak_ptr<PeerConnection>{shared_from_this()}](SctpTransport::State state) {
+        auto strong_this = weak_this.lock();
+        if (!strong_this) return;
+        
 		    switch (state) {
 		    case SctpTransport::State::Connected:
 			    changeState(State::Connected);
@@ -211,7 +224,10 @@ void PeerConnection::initSctpTransport() {
 	    });
 }
 
-bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
+bool PeerConnection::checkFingerprint(weak_ptr<PeerConnection> weak_this, const std::string &fingerprint) const {
+  auto strong_this = weak_this.lock();
+  if (!strong_this) return false;
+  
 	if (auto expectedFingerprint =
 	        mRemoteDescription ? mRemoteDescription->fingerprint() : nullopt) {
 		return *expectedFingerprint == fingerprint;
@@ -219,7 +235,10 @@ bool PeerConnection::checkFingerprint(const std::string &fingerprint) const {
 	return false;
 }
 
-void PeerConnection::forwardMessage(message_ptr message) {
+void PeerConnection::forwardMessage(weak_ptr<PeerConnection> weak_this, message_ptr message) {
+  auto strong_this = weak_this.lock();
+  if (!strong_this) return;
+  
 	if (!mIceTransport || !mSctpTransport)
 		throw std::logic_error("Got a DataChannel message without transport");
 
@@ -243,7 +262,7 @@ void PeerConnection::forwardMessage(message_ptr message) {
 		if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
 		    message->stream % 2 == remoteParity) {
 			channel = std::make_shared<DataChannel>(message->stream, mSctpTransport);
-			channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this, channel));
+			channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this, weak_this, weak_ptr<DataChannel>{channel}));
 			mDataChannels.insert(std::make_pair(message->stream, channel));
 		} else {
 			// Invalid, close the DataChannel by resetting the stream
@@ -288,7 +307,10 @@ void PeerConnection::processLocalDescription(Description description) {
 		mLocalDescriptionCallback(*mLocalDescription);
 }
 
-void PeerConnection::processLocalCandidate(Candidate candidate) {
+void PeerConnection::processLocalCandidate(weak_ptr<PeerConnection> weak_this, Candidate candidate) {
+  auto strong_this = weak_this.lock();
+  if (!strong_this) return;
+  
 	if (!mLocalDescription)
 		throw std::logic_error("Got a local candidate without local description");
 
@@ -298,7 +320,13 @@ void PeerConnection::processLocalCandidate(Candidate candidate) {
 		mLocalCandidateCallback(candidate);
 }
 
-void PeerConnection::triggerDataChannel(std::shared_ptr<DataChannel> dataChannel) {
+void PeerConnection::triggerDataChannel(weak_ptr<PeerConnection> weak_this, weak_ptr<DataChannel> weakDataChannel) {
+  auto strong_this = weak_this.lock();
+  if (!strong_this) return;
+  
+  auto dataChannel = weakDataChannel.lock();
+  if (!dataChannel) return;
+  
 	if (mDataChannelCallback)
 		mDataChannelCallback(dataChannel);
 }

+ 3 - 2
src/sctptransport.cpp

@@ -51,8 +51,8 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, me
                              state_callback stateChangeCallback)
     : Transport(lower), mPort(port), mState(State::Disconnected),
       mStateChangeCallback(std::move(stateChangeCallback)) {
-
-	onRecv(recv);
+  
+  onRecv(recv);
 
 	GlobalInit();
 	usrsctp_register_address(this);
@@ -120,6 +120,7 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, me
 }
 
 SctpTransport::~SctpTransport() {
+  onRecv(nullptr);
 	mStopping = true;
 	mConnectCondition.notify_all();
 	if (mConnectThread.joinable())

+ 19 - 6
test/main.cpp

@@ -26,6 +26,9 @@
 using namespace rtc;
 using namespace std;
 
+template <class T>
+weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
+
 int main(int argc, char **argv) {
 	rtc::Configuration config;
 	// config.iceServers.emplace_back("stun.l.google.com:19302");
@@ -33,12 +36,16 @@ int main(int argc, char **argv) {
 	auto pc1 = std::make_shared<PeerConnection>(config);
 	auto pc2 = std::make_shared<PeerConnection>(config);
 
-	pc1->onLocalDescription([pc2](const Description &sdp) {
+	pc1->onLocalDescription([wpc2 = make_weak_ptr(pc2)](const Description &sdp) {
+    auto pc2 = wpc2.lock();
+    if (!pc2) return;
 		cout << "Description 1: " << sdp << endl;
 		pc2->setRemoteDescription(sdp);
 	});
 
-	pc1->onLocalCandidate([pc2](const Candidate &candidate) {
+	pc1->onLocalCandidate([wpc2 = make_weak_ptr(pc2)](const Candidate &candidate) {
+    auto pc2 = wpc2.lock();
+    if (!pc2) return;
 		cout << "Candidate 1: " << candidate << endl;
 		pc2->addRemoteCandidate(candidate);
 	});
@@ -48,12 +55,16 @@ int main(int argc, char **argv) {
 		cout << "Gathering state 1: " << state << endl;
 	});
 
-	pc2->onLocalDescription([pc1](const Description &sdp) {
+	pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](const Description &sdp) {
+    auto pc1 = wpc1.lock();
+    if (!pc1) return;
 		cout << "Description 2: " << sdp << endl;
 		pc1->setRemoteDescription(sdp);
 	});
 
-	pc2->onLocalCandidate([pc1](const Candidate &candidate) {
+	pc2->onLocalCandidate([wpc1 = make_weak_ptr(pc1)](const Candidate &candidate) {
+    auto pc1 = wpc1.lock();
+    if (!pc1) return;
 		cout << "Candidate 2: " << candidate << endl;
 		pc1->addRemoteCandidate(candidate);
 	});
@@ -76,7 +87,9 @@ int main(int argc, char **argv) {
 	});
 
 	auto dc1 = pc1->createDataChannel("test");
-	dc1->onOpen([dc1]() {
+	dc1->onOpen([wdc1 = make_weak_ptr(dc1)]() {
+    auto dc1 = wdc1.lock();
+    if (!dc1) return;
 		cout << "DataChannel open: " << dc1->label() << endl;
 		dc1->send("Hello from 1");
 	});
@@ -86,6 +99,6 @@ int main(int argc, char **argv) {
 		}
 	});
 
-	this_thread::sleep_for(10s);
+	this_thread::sleep_for(3s);
 }