Browse Source

Properly unregister all transport callbacks when closing

Paul-Louis Ageneau 3 years ago
parent
commit
fca11a6fc0

+ 22 - 13
src/impl/peerconnection.cpp

@@ -37,6 +37,7 @@
 #include <iomanip>
 #include <set>
 #include <thread>
+#include <array>
 
 using namespace std::placeholders;
 
@@ -339,19 +340,27 @@ void PeerConnection::closeTransports() {
 		auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr));
 		auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
 		auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr));
-		ThreadPool::Instance().enqueue(
-		    [sctp = std::move(sctp), dtls = std::move(dtls), ice = std::move(ice)]() mutable {
-			    if (sctp)
-				    sctp->stop();
-			    if (dtls)
-				    dtls->stop();
-			    if (ice)
-				    ice->stop();
-
-			    sctp.reset();
-			    dtls.reset();
-			    ice.reset();
-		    });
+
+		if (sctp) {
+			sctp->onRecv(nullptr);
+			sctp->onBufferedAmount(nullptr);
+		}
+
+		using array = std::array<shared_ptr<Transport>, 3>;
+		array transports{std::move(sctp), std::move(dtls), std::move(ice)};
+
+		for (const auto &t : transports)
+			if (t)
+				t->onStateChange(nullptr);
+
+		ThreadPool::Instance().enqueue([transports = std::move(transports)]() mutable {
+			for (const auto &t : transports)
+				if (t)
+					t->stop();
+
+			for (auto &t : transports)
+				t.reset();
+		});
 	});
 }
 

+ 1 - 3
src/impl/sctptransport.cpp

@@ -184,7 +184,7 @@ SctpTransport::SctpTransport(shared_ptr<Transport> lower, const Configuration &c
                              state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mPort(port),
       mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
-	onRecv(recvCallback);
+	onRecv(std::move(recvCallback));
 
 	PLOG_DEBUG << "Initializing SCTP transport";
 
@@ -350,8 +350,6 @@ bool SctpTransport::stop() {
 	mSendQueue.stop();
 	flush();
 	shutdown();
-	onRecv(nullptr);
-	mBufferedAmountCallback = nullptr;
 	return true;
 }
 

+ 4 - 0
src/impl/sctptransport.hpp

@@ -53,6 +53,10 @@ public:
 	bool flush();
 	void closeStream(unsigned int stream);
 
+	void onBufferedAmount(amount_callback callback) {
+		mBufferedAmountCallback = std::move(callback);
+	}
+
 	// Stats
 	void clearStats();
 	size_t bytesSent();

+ 20 - 13
src/impl/websocket.cpp

@@ -29,6 +29,7 @@
 #include "wstransport.hpp"
 
 #include <regex>
+#include <array>
 
 #ifdef _WIN32
 #include <winsock2.h>
@@ -377,19 +378,25 @@ void WebSocket::closeTransports() {
 	auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
 	auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
 	auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
-	ThreadPool::Instance().enqueue(
-	    [ws = std::move(ws), tls = std::move(tls), tcp = std::move(tcp)]() mutable {
-		    if (ws)
-			    ws->stop();
-		    if (tls)
-			    tls->stop();
-		    if (tcp)
-			    tcp->stop();
-
-		    ws.reset();
-		    tls.reset();
-		    tcp.reset();
-	    });
+
+	if (ws)
+		ws->onRecv(nullptr);
+
+	using array = std::array<shared_ptr<Transport>, 3>;
+	array transports{std::move(ws), std::move(tls), std::move(tcp)};
+
+	for (const auto &t : transports)
+		if (t)
+			t->onStateChange(nullptr);
+
+	ThreadPool::Instance().enqueue([transports = std::move(transports)]() mutable {
+		for (const auto &t : transports)
+			if (t)
+				t->stop();
+
+		for (auto &t : transports)
+			t.reset();
+	});
 
 	triggerClosed();
 }

+ 1 - 1
src/impl/wstransport.cpp

@@ -63,7 +63,7 @@ WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTranspo
                                      [](shared_ptr<TlsTransport> l) { return l->isClient(); }},
                      lower)) {
 
-	onRecv(recvCallback);
+	onRecv(std::move(recvCallback));
 
 	PLOG_DEBUG << "Initializing WebSocket transport";
 }