Browse Source

Merge pull request #511 from paullouisageneau/websocket-reset-callbacks

Do not reset callbacks when WebSocket closes
Paul-Louis Ageneau 3 years ago
parent
commit
db8189985f

+ 22 - 13
src/impl/peerconnection.cpp

@@ -37,6 +37,7 @@
 #include <iomanip>
 #include <set>
 #include <thread>
+#include <array>
 
 using namespace std::placeholders;
 
@@ -339,19 +340,27 @@ void PeerConnection::closeTransports() {
 		auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr));
 		auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
 		auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr));
-		ThreadPool::Instance().enqueue(
-		    [sctp = std::move(sctp), dtls = std::move(dtls), ice = std::move(ice)]() mutable {
-			    if (sctp)
-				    sctp->stop();
-			    if (dtls)
-				    dtls->stop();
-			    if (ice)
-				    ice->stop();
-
-			    sctp.reset();
-			    dtls.reset();
-			    ice.reset();
-		    });
+
+		if (sctp) {
+			sctp->onRecv(nullptr);
+			sctp->onBufferedAmount(nullptr);
+		}
+
+		using array = std::array<shared_ptr<Transport>, 3>;
+		array transports{std::move(sctp), std::move(dtls), std::move(ice)};
+
+		for (const auto &t : transports)
+			if (t)
+				t->onStateChange(nullptr);
+
+		ThreadPool::Instance().enqueue([transports = std::move(transports)]() mutable {
+			for (const auto &t : transports)
+				if (t)
+					t->stop();
+
+			for (auto &t : transports)
+				t.reset();
+		});
 	});
 }
 

+ 1 - 3
src/impl/sctptransport.cpp

@@ -184,7 +184,7 @@ SctpTransport::SctpTransport(shared_ptr<Transport> lower, const Configuration &c
                              state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mPort(port),
       mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
-	onRecv(recvCallback);
+	onRecv(std::move(recvCallback));
 
 	PLOG_DEBUG << "Initializing SCTP transport";
 
@@ -350,8 +350,6 @@ bool SctpTransport::stop() {
 	mSendQueue.stop();
 	flush();
 	shutdown();
-	onRecv(nullptr);
-	mBufferedAmountCallback = nullptr;
 	return true;
 }
 

+ 4 - 0
src/impl/sctptransport.hpp

@@ -53,6 +53,10 @@ public:
 	bool flush();
 	void closeStream(unsigned int stream);
 
+	void onBufferedAmount(amount_callback callback) {
+		mBufferedAmountCallback = std::move(callback);
+	}
+
 	// Stats
 	void clearStats();
 	size_t bytesSent();

+ 26 - 22
src/impl/websocket.cpp

@@ -29,6 +29,7 @@
 #include "wstransport.hpp"
 
 #include <regex>
+#include <array>
 
 #ifdef _WIN32
 #include <winsock2.h>
@@ -322,8 +323,8 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 			case State::Connected:
 				if (state == WebSocket::State::Connecting) {
 					PLOG_DEBUG << "WebSocket open";
-					changeState(WebSocket::State::Open);
-					triggerOpen();
+					if (changeState(WebSocket::State::Open))
+						triggerOpen();
 				}
 				break;
 			case State::Failed:
@@ -370,31 +371,34 @@ shared_ptr<WsHandshake> WebSocket::getWsHandshake() const {
 void WebSocket::closeTransports() {
 	PLOG_VERBOSE << "Closing transports";
 
-	if (state.load() != State::Closed) {
-		changeState(State::Closed);
-		triggerClosed();
-	}
-
-	// Reset callbacks now that state is changed
-	resetCallbacks();
+	if (!changeState(State::Closed))
+		return; // already closed
 
 	// Pass the pointers to a thread, allowing to terminate a transport from its own thread
 	auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
 	auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
 	auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
-	ThreadPool::Instance().enqueue(
-	    [ws = std::move(ws), tls = std::move(tls), tcp = std::move(tcp)]() mutable {
-		    if (ws)
-			    ws->stop();
-		    if (tls)
-			    tls->stop();
-		    if (tcp)
-			    tcp->stop();
-
-		    ws.reset();
-		    tls.reset();
-		    tcp.reset();
-	    });
+
+	if (ws)
+		ws->onRecv(nullptr);
+
+	using array = std::array<shared_ptr<Transport>, 3>;
+	array transports{std::move(ws), std::move(tls), std::move(tcp)};
+
+	for (const auto &t : transports)
+		if (t)
+			t->onStateChange(nullptr);
+
+	ThreadPool::Instance().enqueue([transports = std::move(transports)]() mutable {
+		for (const auto &t : transports)
+			if (t)
+				t->stop();
+
+		for (auto &t : transports)
+			t.reset();
+	});
+
+	triggerClosed();
 }
 
 } // namespace rtc::impl

+ 1 - 1
src/impl/wstransport.cpp

@@ -63,7 +63,7 @@ WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTranspo
                                      [](shared_ptr<TlsTransport> l) { return l->isClient(); }},
                      lower)) {
 
-	onRecv(recvCallback);
+	onRecv(std::move(recvCallback));
 
 	PLOG_DEBUG << "Initializing WebSocket transport";
 }

+ 2 - 11
src/websocket.cpp

@@ -39,6 +39,7 @@ WebSocket::WebSocket(impl_ptr<impl::WebSocket> impl)
 WebSocket::~WebSocket() {
 	try {
 		impl()->remoteClose();
+		impl()->resetCallbacks(); // not done by impl::WebSocket
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 	}
@@ -57,17 +58,7 @@ void WebSocket::open(const string &url) {
 	impl()->open(url);
 }
 
-void WebSocket::close() {
-	auto state = impl()->state.load();
-	if (state == State::Connecting || state == State::Open) {
-		PLOG_VERBOSE << "Closing WebSocket";
-		impl()->changeState(State::Closing);
-		if (auto transport = impl()->getWsTransport())
-			transport->close();
-		else
-			impl()->changeState(State::Closed);
-	}
-}
+void WebSocket::close() { impl()->close(); }
 
 bool WebSocket::send(message_variant data) {
 	return impl()->outgoing(make_message(std::move(data)));